diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h index fe3f92da400f8a..94c8fa092f45e6 100644 --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -569,6 +569,10 @@ class VPIntrinsic : public IntrinsicInst { /// The llvm.vp.* intrinsics for this instruction Opcode static Intrinsic::ID getForOpcode(unsigned OC); + /// The llvm.vp.* intrinsics for this intrinsic ID \p Id. Return \p Id if it + /// is already a VP intrinsic. + static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id); + // Whether \p ID is a VP intrinsic ID. static bool isVPIntrinsic(Intrinsic::ID); diff --git a/llvm/include/llvm/IR/VectorBuilder.h b/llvm/include/llvm/IR/VectorBuilder.h index 6af7f6075551dc..dbb9f4c7336d5e 100644 --- a/llvm/include/llvm/IR/VectorBuilder.h +++ b/llvm/include/llvm/IR/VectorBuilder.h @@ -15,7 +15,6 @@ #ifndef LLVM_IR_VECTORBUILDER_H #define LLVM_IR_VECTORBUILDER_H -#include #include #include #include @@ -100,11 +99,11 @@ class VectorBuilder { const Twine &Name = Twine()); /// Emit a VP reduction intrinsic call for recurrence kind. - /// \param Kind The kind of recurrence + /// \param RdxID The intrinsic ID of llvm.vector.reduce.* /// \param ValTy The type of operand which the reduction operation is /// performed. /// \param VecOpArray The operand list. - Value *createSimpleTargetReduction(RecurKind Kind, Type *ValTy, + Value *createSimpleTargetReduction(Intrinsic::ID RdxID, Type *ValTy, ArrayRef VecOpArray, const Twine &Name = Twine()); }; diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index b01a447f3c28b1..56880bd4822c75 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -359,6 +359,10 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, SinkAndHoistLICMFlags &LICMFlags, OptimizationRemarkEmitter *ORE = nullptr); +/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence +/// kind. +constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK); + /// Returns the arithmetic instruction opcode used when expanding a reduction. unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID); diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp index 64a14da55b15e3..db3b0196f66fd6 100644 --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -599,6 +599,25 @@ Intrinsic::ID VPIntrinsic::getForOpcode(unsigned IROPC) { return Intrinsic::not_intrinsic; } +constexpr static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id) { + if (::isVPIntrinsic(Id)) + return Id; + + switch (Id) { + default: + break; +#define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) break; +#define VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) case Intrinsic::INTRIN: +#define END_REGISTER_VP_INTRINSIC(VPID) return Intrinsic::VPID; +#include "llvm/IR/VPIntrinsics.def" + } + return Intrinsic::not_intrinsic; +} + +Intrinsic::ID VPIntrinsic::getForIntrinsic(Intrinsic::ID Id) { + return ::getForIntrinsic(Id); +} + bool VPIntrinsic::canIgnoreVectorLengthParam() const { using namespace PatternMatch; diff --git a/llvm/lib/IR/VectorBuilder.cpp b/llvm/lib/IR/VectorBuilder.cpp index 5ff30828798950..8dbf25277bf5d2 100644 --- a/llvm/lib/IR/VectorBuilder.cpp +++ b/llvm/lib/IR/VectorBuilder.cpp @@ -60,60 +60,13 @@ Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy, return createVectorInstructionImpl(VPID, ReturnTy, InstOpArray, Name); } -Value *VectorBuilder::createSimpleTargetReduction(RecurKind Kind, Type *ValTy, +Value *VectorBuilder::createSimpleTargetReduction(Intrinsic::ID RdxID, + Type *ValTy, ArrayRef InstOpArray, const Twine &Name) { - Intrinsic::ID VPID; - switch (Kind) { - case RecurKind::Add: - VPID = Intrinsic::vp_reduce_add; - break; - case RecurKind::Mul: - VPID = Intrinsic::vp_reduce_mul; - break; - case RecurKind::And: - VPID = Intrinsic::vp_reduce_and; - break; - case RecurKind::Or: - VPID = Intrinsic::vp_reduce_or; - break; - case RecurKind::Xor: - VPID = Intrinsic::vp_reduce_xor; - break; - case RecurKind::FMulAdd: - case RecurKind::FAdd: - VPID = Intrinsic::vp_reduce_fadd; - break; - case RecurKind::FMul: - VPID = Intrinsic::vp_reduce_fmul; - break; - case RecurKind::SMax: - VPID = Intrinsic::vp_reduce_smax; - break; - case RecurKind::SMin: - VPID = Intrinsic::vp_reduce_smin; - break; - case RecurKind::UMax: - VPID = Intrinsic::vp_reduce_umax; - break; - case RecurKind::UMin: - VPID = Intrinsic::vp_reduce_umin; - break; - case RecurKind::FMax: - VPID = Intrinsic::vp_reduce_fmax; - break; - case RecurKind::FMin: - VPID = Intrinsic::vp_reduce_fmin; - break; - case RecurKind::FMaximum: - VPID = Intrinsic::vp_reduce_fmaximum; - break; - case RecurKind::FMinimum: - VPID = Intrinsic::vp_reduce_fminimum; - break; - default: - llvm_unreachable("No VPIntrinsic for this reduction"); - } + auto VPID = VPIntrinsic::getForIntrinsic(RdxID); + assert(VPReductionIntrinsic::isVPReduction(VPID) && + "No VPIntrinsic for this reduction"); return createVectorInstructionImpl(VPID, ValTy, InstOpArray, Name); } diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 4609376a748f9d..0abf6d77496dcd 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -918,6 +918,44 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } +constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) { + switch (RK) { + default: + llvm_unreachable("Unexpected recurrence kind"); + case RecurKind::Add: + return Intrinsic::vector_reduce_add; + case RecurKind::Mul: + return Intrinsic::vector_reduce_mul; + case RecurKind::And: + return Intrinsic::vector_reduce_and; + case RecurKind::Or: + return Intrinsic::vector_reduce_or; + case RecurKind::Xor: + return Intrinsic::vector_reduce_xor; + case RecurKind::FMulAdd: + case RecurKind::FAdd: + return Intrinsic::vector_reduce_fadd; + case RecurKind::FMul: + return Intrinsic::vector_reduce_fmul; + case RecurKind::SMax: + return Intrinsic::vector_reduce_smax; + case RecurKind::SMin: + return Intrinsic::vector_reduce_smin; + case RecurKind::UMax: + return Intrinsic::vector_reduce_umax; + case RecurKind::UMin: + return Intrinsic::vector_reduce_umin; + case RecurKind::FMax: + return Intrinsic::vector_reduce_fmax; + case RecurKind::FMin: + return Intrinsic::vector_reduce_fmin; + case RecurKind::FMaximum: + return Intrinsic::vector_reduce_fmaximum; + case RecurKind::FMinimum: + return Intrinsic::vector_reduce_fminimum; + } +} + unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) { switch (RdxID) { case Intrinsic::vector_reduce_fadd: @@ -1215,12 +1253,13 @@ Value *llvm::createSimpleTargetReduction(VectorBuilder &VBuilder, Value *Src, RecurKind Kind = Desc.getRecurrenceKind(); assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) && "AnyOf reduction is not supported."); + Intrinsic::ID Id = getReductionIntrinsicID(Kind); auto *SrcTy = cast(Src->getType()); Type *SrcEltTy = SrcTy->getElementType(); Value *Iden = Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags()); Value *Ops[] = {Iden, Src}; - return VBuilder.createSimpleTargetReduction(Kind, SrcTy, Ops); + return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops); } Value *llvm::createTargetReduction(IRBuilderBase &B, @@ -1260,9 +1299,10 @@ Value *llvm::createOrderedReduction(VectorBuilder &VBuilder, assert(Src->getType()->isVectorTy() && "Expected a vector type"); assert(!Start->getType()->isVectorTy() && "Expected a scalar type"); + Intrinsic::ID Id = getReductionIntrinsicID(RecurKind::FAdd); auto *SrcTy = cast(Src->getType()); Value *Ops[] = {Start, Src}; - return VBuilder.createSimpleTargetReduction(RecurKind::FAdd, SrcTy, Ops); + return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops); } void llvm::propagateIRFlags(Value *I, ArrayRef VL, Value *OpValue, diff --git a/llvm/unittests/IR/VPIntrinsicTest.cpp b/llvm/unittests/IR/VPIntrinsicTest.cpp index eab2850ca4e1e8..cf0a10d1f2e959 100644 --- a/llvm/unittests/IR/VPIntrinsicTest.cpp +++ b/llvm/unittests/IR/VPIntrinsicTest.cpp @@ -367,6 +367,59 @@ TEST_F(VPIntrinsicTest, IntrinsicIDRoundTrip) { ASSERT_NE(FullTripCounts, 0u); } +/// Check that going from intrinsic to VP intrinsic and back results in the same +/// intrinsic. +TEST_F(VPIntrinsicTest, IntrinsicToVPRoundTrip) { + bool IsFullTrip = false; + Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic + 1; + for (; IntrinsicID < Intrinsic::num_intrinsics; IntrinsicID++) { + Intrinsic::ID VPID = VPIntrinsic::getForIntrinsic(IntrinsicID); + // No equivalent VP intrinsic available. + if (VPID == Intrinsic::not_intrinsic) + continue; + + // Return itself if passed intrinsic ID is VP intrinsic. + if (VPIntrinsic::isVPIntrinsic(IntrinsicID)) { + ASSERT_EQ(IntrinsicID, VPID); + continue; + } + + std::optional RoundTripIntrinsicID = + VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID); + // No equivalent non-predicated intrinsic available. + if (!RoundTripIntrinsicID) + continue; + + ASSERT_EQ(*RoundTripIntrinsicID, IntrinsicID); + IsFullTrip = true; + } + ASSERT_TRUE(IsFullTrip); +} + +/// Check that going from VP intrinsic to equivalent non-predicated intrinsic +/// and back results in the same intrinsic. +TEST_F(VPIntrinsicTest, VPToNonPredIntrinsicRoundTrip) { + std::unique_ptr M = createVPDeclarationModule(); + assert(M); + + bool IsFullTrip = false; + for (const auto &VPDecl : *M) { + auto VPID = VPDecl.getIntrinsicID(); + std::optional NonPredID = + VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID); + + // No equivalent non-predicated intrinsic available + if (!NonPredID) + continue; + + Intrinsic::ID RoundTripVPID = VPIntrinsic::getForIntrinsic(*NonPredID); + + ASSERT_EQ(RoundTripVPID, VPID); + IsFullTrip = true; + } + ASSERT_TRUE(IsFullTrip); +} + /// Check that VPIntrinsic::getDeclarationForParams works. TEST_F(VPIntrinsicTest, VPIntrinsicDeclarationForParams) { std::unique_ptr M = createVPDeclarationModule();