Skip to content

Commit

Permalink
[VP] Refactor VectorBuilder to avoid layering violation. NFC (#99276)
Browse files Browse the repository at this point in the history
This patch refactors the handling of reduction to eliminate layering
violations.

* Introduced `getReductionIntrinsicID` in LoopUtils.h for mapping
recurrence kinds to llvm.vector.reduce.* intrinsic IDs.
* Updated `VectorBuilder::createSimpleTargetReduction` to accept
llvm.vector.reduce.* intrinsic directly.
* New function `VPIntrinsic::getForIntrinsic` for mapping intrinsic ID
to the same functional VP intrinsic ID.
  • Loading branch information
Mel-Chen authored Jul 25, 2024
1 parent 693d757 commit 6d12b3f
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 57 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,10 @@ class VPIntrinsic : public IntrinsicInst {
/// The llvm.vp.* intrinsics for this instruction Opcode
static Intrinsic::ID getForOpcode(unsigned OC);

/// The llvm.vp.* intrinsics for this intrinsic ID \p Id. Return \p Id if it
/// is already a VP intrinsic.
static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id);

// Whether \p ID is a VP intrinsic ID.
static bool isVPIntrinsic(Intrinsic::ID);

Expand Down
5 changes: 2 additions & 3 deletions llvm/include/llvm/IR/VectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#ifndef LLVM_IR_VECTORBUILDER_H
#define LLVM_IR_VECTORBUILDER_H

#include <llvm/Analysis/IVDescriptors.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/Instruction.h>
Expand Down Expand Up @@ -100,11 +99,11 @@ class VectorBuilder {
const Twine &Name = Twine());

/// Emit a VP reduction intrinsic call for recurrence kind.
/// \param Kind The kind of recurrence
/// \param RdxID The intrinsic ID of llvm.vector.reduce.*
/// \param ValTy The type of operand which the reduction operation is
/// performed.
/// \param VecOpArray The operand list.
Value *createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
Value *createSimpleTargetReduction(Intrinsic::ID RdxID, Type *ValTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());
};
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,10 @@ bool canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
SinkAndHoistLICMFlags &LICMFlags,
OptimizationRemarkEmitter *ORE = nullptr);

/// Returns the llvm.vector.reduce intrinsic that corresponds to the recurrence
/// kind.
constexpr Intrinsic::ID getReductionIntrinsicID(RecurKind RK);

/// Returns the arithmetic instruction opcode used when expanding a reduction.
unsigned getArithmeticReductionInstruction(Intrinsic::ID RdxID);

Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/IR/IntrinsicInst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,25 @@ Intrinsic::ID VPIntrinsic::getForOpcode(unsigned IROPC) {
return Intrinsic::not_intrinsic;
}

constexpr static Intrinsic::ID getForIntrinsic(Intrinsic::ID Id) {
if (::isVPIntrinsic(Id))
return Id;

switch (Id) {
default:
break;
#define BEGIN_REGISTER_VP_INTRINSIC(VPID, ...) break;
#define VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) case Intrinsic::INTRIN:
#define END_REGISTER_VP_INTRINSIC(VPID) return Intrinsic::VPID;
#include "llvm/IR/VPIntrinsics.def"
}
return Intrinsic::not_intrinsic;
}

Intrinsic::ID VPIntrinsic::getForIntrinsic(Intrinsic::ID Id) {
return ::getForIntrinsic(Id);
}

bool VPIntrinsic::canIgnoreVectorLengthParam() const {
using namespace PatternMatch;

Expand Down
57 changes: 5 additions & 52 deletions llvm/lib/IR/VectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,60 +60,13 @@ Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy,
return createVectorInstructionImpl(VPID, ReturnTy, InstOpArray, Name);
}

Value *VectorBuilder::createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
Value *VectorBuilder::createSimpleTargetReduction(Intrinsic::ID RdxID,
Type *ValTy,
ArrayRef<Value *> InstOpArray,
const Twine &Name) {
Intrinsic::ID VPID;
switch (Kind) {
case RecurKind::Add:
VPID = Intrinsic::vp_reduce_add;
break;
case RecurKind::Mul:
VPID = Intrinsic::vp_reduce_mul;
break;
case RecurKind::And:
VPID = Intrinsic::vp_reduce_and;
break;
case RecurKind::Or:
VPID = Intrinsic::vp_reduce_or;
break;
case RecurKind::Xor:
VPID = Intrinsic::vp_reduce_xor;
break;
case RecurKind::FMulAdd:
case RecurKind::FAdd:
VPID = Intrinsic::vp_reduce_fadd;
break;
case RecurKind::FMul:
VPID = Intrinsic::vp_reduce_fmul;
break;
case RecurKind::SMax:
VPID = Intrinsic::vp_reduce_smax;
break;
case RecurKind::SMin:
VPID = Intrinsic::vp_reduce_smin;
break;
case RecurKind::UMax:
VPID = Intrinsic::vp_reduce_umax;
break;
case RecurKind::UMin:
VPID = Intrinsic::vp_reduce_umin;
break;
case RecurKind::FMax:
VPID = Intrinsic::vp_reduce_fmax;
break;
case RecurKind::FMin:
VPID = Intrinsic::vp_reduce_fmin;
break;
case RecurKind::FMaximum:
VPID = Intrinsic::vp_reduce_fmaximum;
break;
case RecurKind::FMinimum:
VPID = Intrinsic::vp_reduce_fminimum;
break;
default:
llvm_unreachable("No VPIntrinsic for this reduction");
}
auto VPID = VPIntrinsic::getForIntrinsic(RdxID);
assert(VPReductionIntrinsic::isVPReduction(VPID) &&
"No VPIntrinsic for this reduction");
return createVectorInstructionImpl(VPID, ValTy, InstOpArray, Name);
}

Expand Down
44 changes: 42 additions & 2 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,44 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
return true;
}

constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
switch (RK) {
default:
llvm_unreachable("Unexpected recurrence kind");
case RecurKind::Add:
return Intrinsic::vector_reduce_add;
case RecurKind::Mul:
return Intrinsic::vector_reduce_mul;
case RecurKind::And:
return Intrinsic::vector_reduce_and;
case RecurKind::Or:
return Intrinsic::vector_reduce_or;
case RecurKind::Xor:
return Intrinsic::vector_reduce_xor;
case RecurKind::FMulAdd:
case RecurKind::FAdd:
return Intrinsic::vector_reduce_fadd;
case RecurKind::FMul:
return Intrinsic::vector_reduce_fmul;
case RecurKind::SMax:
return Intrinsic::vector_reduce_smax;
case RecurKind::SMin:
return Intrinsic::vector_reduce_smin;
case RecurKind::UMax:
return Intrinsic::vector_reduce_umax;
case RecurKind::UMin:
return Intrinsic::vector_reduce_umin;
case RecurKind::FMax:
return Intrinsic::vector_reduce_fmax;
case RecurKind::FMin:
return Intrinsic::vector_reduce_fmin;
case RecurKind::FMaximum:
return Intrinsic::vector_reduce_fmaximum;
case RecurKind::FMinimum:
return Intrinsic::vector_reduce_fminimum;
}
}

unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
switch (RdxID) {
case Intrinsic::vector_reduce_fadd:
Expand Down Expand Up @@ -1215,12 +1253,13 @@ Value *llvm::createSimpleTargetReduction(VectorBuilder &VBuilder, Value *Src,
RecurKind Kind = Desc.getRecurrenceKind();
assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
"AnyOf reduction is not supported.");
Intrinsic::ID Id = getReductionIntrinsicID(Kind);
auto *SrcTy = cast<VectorType>(Src->getType());
Type *SrcEltTy = SrcTy->getElementType();
Value *Iden =
Desc.getRecurrenceIdentity(Kind, SrcEltTy, Desc.getFastMathFlags());
Value *Ops[] = {Iden, Src};
return VBuilder.createSimpleTargetReduction(Kind, SrcTy, Ops);
return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops);
}

Value *llvm::createTargetReduction(IRBuilderBase &B,
Expand Down Expand Up @@ -1260,9 +1299,10 @@ Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
assert(Src->getType()->isVectorTy() && "Expected a vector type");
assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

Intrinsic::ID Id = getReductionIntrinsicID(RecurKind::FAdd);
auto *SrcTy = cast<VectorType>(Src->getType());
Value *Ops[] = {Start, Src};
return VBuilder.createSimpleTargetReduction(RecurKind::FAdd, SrcTy, Ops);
return VBuilder.createSimpleTargetReduction(Id, SrcTy, Ops);
}

void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
Expand Down
53 changes: 53 additions & 0 deletions llvm/unittests/IR/VPIntrinsicTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,59 @@ TEST_F(VPIntrinsicTest, IntrinsicIDRoundTrip) {
ASSERT_NE(FullTripCounts, 0u);
}

/// Check that going from intrinsic to VP intrinsic and back results in the same
/// intrinsic.
TEST_F(VPIntrinsicTest, IntrinsicToVPRoundTrip) {
bool IsFullTrip = false;
Intrinsic::ID IntrinsicID = Intrinsic::not_intrinsic + 1;
for (; IntrinsicID < Intrinsic::num_intrinsics; IntrinsicID++) {
Intrinsic::ID VPID = VPIntrinsic::getForIntrinsic(IntrinsicID);
// No equivalent VP intrinsic available.
if (VPID == Intrinsic::not_intrinsic)
continue;

// Return itself if passed intrinsic ID is VP intrinsic.
if (VPIntrinsic::isVPIntrinsic(IntrinsicID)) {
ASSERT_EQ(IntrinsicID, VPID);
continue;
}

std::optional<Intrinsic::ID> RoundTripIntrinsicID =
VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);
// No equivalent non-predicated intrinsic available.
if (!RoundTripIntrinsicID)
continue;

ASSERT_EQ(*RoundTripIntrinsicID, IntrinsicID);
IsFullTrip = true;
}
ASSERT_TRUE(IsFullTrip);
}

/// Check that going from VP intrinsic to equivalent non-predicated intrinsic
/// and back results in the same intrinsic.
TEST_F(VPIntrinsicTest, VPToNonPredIntrinsicRoundTrip) {
std::unique_ptr<Module> M = createVPDeclarationModule();
assert(M);

bool IsFullTrip = false;
for (const auto &VPDecl : *M) {
auto VPID = VPDecl.getIntrinsicID();
std::optional<Intrinsic::ID> NonPredID =
VPIntrinsic::getFunctionalIntrinsicIDForVP(VPID);

// No equivalent non-predicated intrinsic available
if (!NonPredID)
continue;

Intrinsic::ID RoundTripVPID = VPIntrinsic::getForIntrinsic(*NonPredID);

ASSERT_EQ(RoundTripVPID, VPID);
IsFullTrip = true;
}
ASSERT_TRUE(IsFullTrip);
}

/// Check that VPIntrinsic::getDeclarationForParams works.
TEST_F(VPIntrinsicTest, VPIntrinsicDeclarationForParams) {
std::unique_ptr<Module> M = createVPDeclarationModule();
Expand Down

0 comments on commit 6d12b3f

Please sign in to comment.