Skip to content

Commit

Permalink
[LV][EVL] Support in-loop reduction using tail folding with EVL. (llv…
Browse files Browse the repository at this point in the history
…m#90184)

Following from llvm#87816, add VPReductionEVLRecipe to describe vector
predication reduction.

Address one of TODOs from llvm#76172.
  • Loading branch information
Mel-Chen authored Jul 16, 2024
1 parent 5d12fa7 commit 4eb30cf
Show file tree
Hide file tree
Showing 18 changed files with 5,344 additions and 96 deletions.
15 changes: 15 additions & 0 deletions llvm/include/llvm/IR/VectorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef LLVM_IR_VECTORBUILDER_H
#define LLVM_IR_VECTORBUILDER_H

#include <llvm/Analysis/IVDescriptors.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/InstrTypes.h>
#include <llvm/IR/Instruction.h>
Expand Down Expand Up @@ -57,6 +58,11 @@ class VectorBuilder {
return RetType();
}

/// Helper function for creating VP intrinsic call.
Value *createVectorInstructionImpl(Intrinsic::ID VPID, Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());

public:
VectorBuilder(IRBuilderBase &Builder,
Behavior ErrorHandling = Behavior::ReportAndAbort)
Expand Down Expand Up @@ -92,6 +98,15 @@ class VectorBuilder {
Value *createVectorInstruction(unsigned Opcode, Type *ReturnTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());

/// Emit a VP reduction intrinsic call for recurrence kind.
/// \param Kind The kind of recurrence
/// \param ValTy The type of the operand on which the reduction operation is
/// performed.
/// \param VecOpArray The operand list.
Value *createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
ArrayRef<Value *> VecOpArray,
const Twine &Name = Twine());
};

} // namespace llvm
Expand Down
10 changes: 10 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/VectorBuilder.h"
#include "llvm/Transforms/Utils/ValueMapper.h"

namespace llvm {
Expand Down Expand Up @@ -394,6 +395,10 @@ Value *getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
/// Fast-math-flags are propagated using the IRBuilder's setting.
Value *createSimpleTargetReduction(IRBuilderBase &B, Value *Src,
RecurKind RdxKind);
/// Overloaded function to generate vector-predication intrinsics for target
/// reduction.
Value *createSimpleTargetReduction(VectorBuilder &VB, Value *Src,
const RecurrenceDescriptor &Desc);

/// Create a target reduction of the given vector \p Src for a reduction of the
/// kind RecurKind::IAnyOf or RecurKind::FAnyOf. The reduction operation is
Expand All @@ -414,6 +419,11 @@ Value *createTargetReduction(IRBuilderBase &B, const RecurrenceDescriptor &Desc,
Value *createOrderedReduction(IRBuilderBase &B,
const RecurrenceDescriptor &Desc, Value *Src,
Value *Start);
/// Overloaded function to generate vector-predication intrinsics for ordered
/// reduction.
Value *createOrderedReduction(VectorBuilder &VB,
const RecurrenceDescriptor &Desc, Value *Src,
Value *Start);

/// Get the intersection (logical and) of all of the potential IR flags
/// of each scalar operation (VL) that will be converted into a vector (I).
Expand Down
63 changes: 63 additions & 0 deletions llvm/lib/IR/VectorBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,70 @@ Value *VectorBuilder::createVectorInstruction(unsigned Opcode, Type *ReturnTy,
auto VPID = VPIntrinsic::getForOpcode(Opcode);
if (VPID == Intrinsic::not_intrinsic)
return returnWithError<Value *>("No VPIntrinsic for this opcode");
return createVectorInstructionImpl(VPID, ReturnTy, InstOpArray, Name);
}

/// Emit the VP reduction intrinsic call matching recurrence kind \p Kind.
/// FMulAdd shares vp.reduce.fadd with FAdd; any kind without a mapping below
/// is treated as unreachable. \p ValTy, \p InstOpArray and \p Name are
/// forwarded unchanged to createVectorInstructionImpl.
Value *VectorBuilder::createSimpleTargetReduction(RecurKind Kind, Type *ValTy,
                                                  ArrayRef<Value *> InstOpArray,
                                                  const Twine &Name) {
  // Translate the recurrence kind into its vp.reduce.* intrinsic ID.
  const Intrinsic::ID VPID = [Kind]() -> Intrinsic::ID {
    switch (Kind) {
    // Integer arithmetic and bitwise reductions.
    case RecurKind::Add:
      return Intrinsic::vp_reduce_add;
    case RecurKind::Mul:
      return Intrinsic::vp_reduce_mul;
    case RecurKind::And:
      return Intrinsic::vp_reduce_and;
    case RecurKind::Or:
      return Intrinsic::vp_reduce_or;
    case RecurKind::Xor:
      return Intrinsic::vp_reduce_xor;
    // Floating-point arithmetic reductions.
    case RecurKind::FMulAdd:
    case RecurKind::FAdd:
      return Intrinsic::vp_reduce_fadd;
    case RecurKind::FMul:
      return Intrinsic::vp_reduce_fmul;
    // Integer min/max reductions.
    case RecurKind::SMax:
      return Intrinsic::vp_reduce_smax;
    case RecurKind::SMin:
      return Intrinsic::vp_reduce_smin;
    case RecurKind::UMax:
      return Intrinsic::vp_reduce_umax;
    case RecurKind::UMin:
      return Intrinsic::vp_reduce_umin;
    // Floating-point min/max reductions.
    case RecurKind::FMax:
      return Intrinsic::vp_reduce_fmax;
    case RecurKind::FMin:
      return Intrinsic::vp_reduce_fmin;
    case RecurKind::FMaximum:
      return Intrinsic::vp_reduce_fmaximum;
    case RecurKind::FMinimum:
      return Intrinsic::vp_reduce_fminimum;
    default:
      llvm_unreachable("No VPIntrinsic for this reduction");
    }
  }();
  return createVectorInstructionImpl(VPID, ValTy, InstOpArray, Name);
}

Value *VectorBuilder::createVectorInstructionImpl(Intrinsic::ID VPID,
Type *ReturnTy,
ArrayRef<Value *> InstOpArray,
const Twine &Name) {
auto MaskPosOpt = VPIntrinsic::getMaskParamPos(VPID);
auto VLenPosOpt = VPIntrinsic::getVectorLengthParamPos(VPID);
size_t NumInstParams = InstOpArray.size();
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
}
}

/// VectorBuilder flavour of the simple target reduction. Passes the
/// recurrence identity of \p Desc (computed for the element type of \p Src)
/// as the first operand and \p Src as the second. AnyOf recurrence kinds are
/// rejected by assertion.
Value *llvm::createSimpleTargetReduction(VectorBuilder &VBuilder, Value *Src,
                                         const RecurrenceDescriptor &Desc) {
  const RecurKind RK = Desc.getRecurrenceKind();
  assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
         "AnyOf reduction is not supported.");

  auto *VecTy = cast<VectorType>(Src->getType());
  Value *Identity = Desc.getRecurrenceIdentity(RK, VecTy->getElementType(),
                                               Desc.getFastMathFlags());
  // Operand order expected by the builder call: {identity, vector}.
  Value *RdxOperands[] = {Identity, Src};
  return VBuilder.createSimpleTargetReduction(RK, VecTy, RdxOperands);
}

Value *llvm::createTargetReduction(IRBuilderBase &B,
const RecurrenceDescriptor &Desc, Value *Src,
PHINode *OrigPhi) {
Expand Down Expand Up @@ -1220,6 +1233,20 @@ Value *llvm::createOrderedReduction(IRBuilderBase &B,
return B.CreateFAddReduce(Start, Src);
}

/// VectorBuilder flavour of the ordered reduction. Only FAdd and FMulAdd
/// recurrences are accepted (checked by assertion); both are emitted as an
/// FAdd reduction with the scalar \p Start as the first operand and the
/// vector \p Src as the second.
Value *llvm::createOrderedReduction(VectorBuilder &VBuilder,
                                    const RecurrenceDescriptor &Desc,
                                    Value *Src, Value *Start) {
  assert((Desc.getRecurrenceKind() == RecurKind::FAdd ||
          Desc.getRecurrenceKind() == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

  // Operand order expected by the builder call: {start, vector}.
  Value *RdxOperands[] = {Start, Src};
  return VBuilder.createSimpleTargetReduction(
      RecurKind::FAdd, cast<VectorType>(Src->getType()), RdxOperands);
}

void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
bool IncludeWrapFlags) {
auto *VecOp = dyn_cast<Instruction>(I);
Expand Down
4 changes: 1 addition & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1516,9 +1516,7 @@ class LoopVectorizationCostModel {
TTI.hasActiveVectorLength(0, nullptr, Align()) &&
!EnableVPlanNativePath &&
// FIXME: implement support for max safe dependency distance.
Legal->isSafeForAnyVectorWidth() &&
// FIXME: remove this once reductions are supported.
Legal->getReductionVars().empty();
Legal->isSafeForAnyVectorWidth();
if (!EVLIsLegal) {
// If for some reason EVL mode is unsupported, fallback to
// DataWithoutLaneMask to try to vectorize the loop with folded tail
Expand Down
84 changes: 76 additions & 8 deletions llvm/lib/Transforms/Vectorize/VPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
case VPRecipeBase::VPEVLBasedIVPHISC:
case VPRecipeBase::VPExpandSCEVSC:
case VPRecipeBase::VPInstructionSC:
case VPRecipeBase::VPReductionEVLSC:
case VPRecipeBase::VPReductionSC:
case VPRecipeBase::VPReplicateSC:
case VPRecipeBase::VPScalarIVStepsSC:
Expand Down Expand Up @@ -2171,17 +2172,27 @@ class VPReductionRecipe : public VPSingleDefRecipe {
/// The recurrence descriptor for the reduction in question.
const RecurrenceDescriptor &RdxDesc;
bool IsOrdered;
/// Whether the reduction is conditional.
bool IsConditional = false;

protected:
VPReductionRecipe(const unsigned char SC, const RecurrenceDescriptor &R,
Instruction *I, ArrayRef<VPValue *> Operands,
VPValue *CondOp, bool IsOrdered)
: VPSingleDefRecipe(SC, Operands, I), RdxDesc(R), IsOrdered(IsOrdered) {
if (CondOp) {
IsConditional = true;
addOperand(CondOp);
}
}

public:
VPReductionRecipe(const RecurrenceDescriptor &R, Instruction *I,
VPValue *ChainOp, VPValue *VecOp, VPValue *CondOp,
bool IsOrdered)
: VPSingleDefRecipe(VPDef::VPReductionSC,
ArrayRef<VPValue *>({ChainOp, VecOp}), I),
RdxDesc(R), IsOrdered(IsOrdered) {
if (CondOp)
addOperand(CondOp);
}
: VPReductionRecipe(VPDef::VPReductionSC, R, I,
ArrayRef<VPValue *>({ChainOp, VecOp}), CondOp,
IsOrdered) {}

~VPReductionRecipe() override = default;

Expand All @@ -2190,7 +2201,15 @@ class VPReductionRecipe : public VPSingleDefRecipe {
getVecOp(), getCondOp(), IsOrdered);
}

VP_CLASSOF_IMPL(VPDef::VPReductionSC)
/// A recipe is a VPReductionRecipe if its def ID is either the plain
/// reduction ID or the EVL reduction ID.
static inline bool classof(const VPRecipeBase *R) {
  const auto DefID = R->getVPDefID();
  return DefID == VPRecipeBase::VPReductionSC ||
         DefID == VPRecipeBase::VPReductionEVLSC;
}

/// A VPUser matches when it is a VPRecipeBase that passes the recipe-level
/// classof check.
static inline bool classof(const VPUser *U) {
  if (auto *Recipe = dyn_cast<VPRecipeBase>(U))
    return classof(Recipe);
  return false;
}

/// Generate the reduction in the loop
void execute(VPTransformState &State) override;
Expand All @@ -2201,13 +2220,62 @@ class VPReductionRecipe : public VPSingleDefRecipe {
VPSlotTracker &SlotTracker) const override;
#endif

/// Return the recurrence descriptor for the in-loop reduction.
const RecurrenceDescriptor &getRecurrenceDescriptor() const {
return RdxDesc;
}
/// Return true if the in-loop reduction is ordered, as recorded at
/// construction time.
bool isOrdered() const { return IsOrdered; }
/// Return true if the in-loop reduction is conditional, i.e. a condition
/// operand was supplied at construction time.
bool isConditional() const { return IsConditional; }
/// The VPValue of the scalar Chain being accumulated (operand 0).
VPValue *getChainOp() const { return getOperand(0); }
/// The VPValue of the vector value to be reduced (operand 1).
VPValue *getVecOp() const { return getOperand(1); }
/// The VPValue of the condition for the block.
VPValue *getCondOp() const {
return getNumOperands() > 2 ? getOperand(2) : nullptr;
return isConditional() ? getOperand(getNumOperands() - 1) : nullptr;
}
};

/// In-loop reduction recipe expressed with vector-predication intrinsics: a
/// vector operand is reduced under an explicit vector length (EVL) into a
/// scalar value, and the result is added to a chain.
/// Operand layout: {ChainOp, VecOp, EVL, [Condition]}.
class VPReductionEVLRecipe : public VPReductionRecipe {
public:
  /// Build an EVL reduction from \p R, reusing its recurrence descriptor,
  /// underlying instruction (if any), chain/vector operands and ordering.
  /// \p EVL becomes operand 2; \p CondOp is handed to the base constructor.
  VPReductionEVLRecipe(VPReductionRecipe *R, VPValue *EVL, VPValue *CondOp)
      : VPReductionRecipe(
            VPDef::VPReductionEVLSC, R->getRecurrenceDescriptor(),
            cast_or_null<Instruction>(R->getUnderlyingValue()),
            ArrayRef<VPValue *>({R->getChainOp(), R->getVecOp(), EVL}), CondOp,
            R->isOrdered()) {}

  ~VPReductionEVLRecipe() override = default;

  /// Cloning is not supported yet; reaching this is a programming error.
  VPReductionEVLRecipe *clone() override {
    llvm_unreachable("cloning not implemented yet");
  }

  VP_CLASSOF_IMPL(VPDef::VPReductionEVLSC)

  /// Generate the reduction in the loop.
  void execute(VPTransformState &State) override;

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
  /// Print the recipe.
  void print(raw_ostream &O, const Twine &Indent,
             VPSlotTracker &SlotTracker) const override;
#endif

  /// The VPValue of the explicit vector length (operand 2).
  VPValue *getEVL() const { return getOperand(2); }

  /// Returns true if the recipe only uses the first lane of operand \p Op;
  /// this holds only for the EVL operand.
  bool onlyFirstLaneUsed(const VPValue *Op) const override {
    assert(is_contained(operands(), Op) &&
           "Op must be an operand of the recipe");
    return getEVL() == Op;
  }
};

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
[](const VPScalarCastRecipe *R) { return R->getResultType(); })
.Case<VPExpandSCEVRecipe>([](const VPExpandSCEVRecipe *R) {
return R->getSCEV()->getType();
})
.Case<VPReductionRecipe>([this](const auto *R) {
return inferScalarType(R->getChainOp());
});

assert(ResultTy && "could not infer type for the given VPValue");
Expand Down
Loading

0 comments on commit 4eb30cf

Please sign in to comment.