Reland "[LoopVectorizer] Add support for partial reductions" with non…
Browse files Browse the repository at this point in the history
…-phi operand fix. (llvm#121744)

This relands the reverted llvm#120721 with a fix for cases where neither reduction operand is the reduction phi. Only 6311423 and 6311423 are new on top of the reverted PR.

---------

Co-authored-by: Nicholas Guy <[email protected]>
2 people authored and kazutakahirata committed Jan 13, 2025
1 parent d04031a commit 6b23b48
Showing 17 changed files with 4,588 additions and 31 deletions.
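For orientation before the diffs: a partial reduction folds a wide input vector into a narrower accumulator of wider elements. Below is a minimal scalar model of the llvm.experimental.partial.reduce.add intrinsic named in the diffs; it is an illustrative sketch, not code from this patch, and since the intrinsic only guarantees the overall sum, the dot-product-style lane grouping used here is just one valid choice.

#include <array>
#include <cstdint>

// Scalar model of a <16 x i32> -> <4 x i32> partial add reduction: four
// adjacent input lanes fold into each accumulator lane. Only the total
// sum is fixed by the intrinsic; this particular grouping matches the
// sdot/udot style of lowering.
std::array<int32_t, 4> partialReduceAdd(std::array<int32_t, 4> Acc,
                                        const std::array<int32_t, 16> &In) {
  for (unsigned I = 0; I < 16; ++I)
    Acc[I / 4] += In[I];
  return Acc;
}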
44 changes: 44 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -211,6 +211,12 @@ typedef TargetTransformInfo TTI;
/// for IR-level transformations.
class TargetTransformInfo {
public:
enum PartialReductionExtendKind { PR_None, PR_SignExtend, PR_ZeroExtend };

/// Get the kind of extension that an instruction represents.
static PartialReductionExtendKind
getPartialReductionExtendKind(Instruction *I);

/// Construct a TTI object using a type implementing the \c Concept
/// API below.
///
@@ -1280,6 +1286,20 @@ class TargetTransformInfo {
/// \return if target want to issue a prefetch in address space \p AS.
bool shouldPrefetchAddressSpace(unsigned AS) const;

/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces two vectors to another of 4 times fewer and 4
/// times larger elements.
InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const;

/// \return The maximum interleave factor that any transform should try to
/// perform for this target. This number depends on the level of parallelism
/// and the number of execution units in the CPU.
@@ -2107,6 +2127,20 @@ class TargetTransformInfo::Concept {
/// \return if target want to issue a prefetch in address space \p AS.
virtual bool shouldPrefetchAddressSpace(unsigned AS) const = 0;

/// \return The cost of a partial reduction, which is a reduction from a
/// vector to another vector with fewer elements of larger size. They are
/// represented by the llvm.experimental.partial.reduce.add intrinsic, which
/// takes an accumulator and a binary operation operand that itself is fed by
/// two extends. An example of an operation that uses a partial reduction is a
/// dot product, which reduces two vectors to another of 4 times fewer and 4
/// times larger elements.
virtual InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp) const = 0;

virtual unsigned getMaxInterleaveFactor(ElementCount VF) = 0;
virtual InstructionCost getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
@@ -2786,6 +2820,16 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.shouldPrefetchAddressSpace(AS);
}

InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const override {
return Impl.getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
AccumType, VF, OpAExtend, OpBExtend,
BinOp);
}

unsigned getMaxInterleaveFactor(ElementCount VF) override {
return Impl.getMaxInterleaveFactor(VF);
}
9 changes: 9 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -585,6 +585,15 @@ class TargetTransformInfoImplBase {
bool enableWritePrefetching() const { return false; }
bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }

InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp = std::nullopt) const {
return InstructionCost::getInvalid();
}

unsigned getMaxInterleaveFactor(ElementCount VF) const { return 1; }

InstructionCost getArithmeticInstrCost(
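The base-class default above returns InstructionCost::getInvalid(), which is the opt-out signal for any target that does not override the hook. A hedged sketch of how a caller interprets it (the helper name is illustrative; this mirrors the Cost.isValid() gate in VPRecipeBuilder::getScaledReduction later in this patch):

#include "llvm/Support/InstructionCost.h"
using namespace llvm;

// An invalid cost means "no profitable partial-reduction lowering here",
// so the vectorizer keeps the ordinary widened-reduction plan.
bool targetSupportsPartialReduction(InstructionCost Cost) {
  return Cost.isValid(); // the default above makes this false everywhere
}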
18 changes: 18 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -863,6 +863,15 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
return TTIImpl->shouldPrefetchAddressSpace(AS);
}

InstructionCost TargetTransformInfo::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
AccumType, VF, OpAExtend, OpBExtend,
BinOp);
}

unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
return TTIImpl->getMaxInterleaveFactor(VF);
}
@@ -974,6 +983,15 @@ InstructionCost TargetTransformInfo::getShuffleCost(
return Cost;
}

TargetTransformInfo::PartialReductionExtendKind
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
if (isa<SExtInst>(I))
return PR_SignExtend;
if (isa<ZExtInst>(I))
return PR_ZeroExtend;
return PR_None;
}

TTI::CastContextHint
TargetTransformInfo::getCastContextHint(const Instruction *I) {
if (!I)
63 changes: 63 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -23,6 +23,7 @@
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/Support/InstructionCost.h"
#include <cstdint>
#include <optional>

@@ -357,6 +358,68 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
return BaseT::isLegalNTLoad(DataType, Alignment);
}

InstructionCost
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
Type *AccumType, ElementCount VF,
TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend,
std::optional<unsigned> BinOp) const {

InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic);

if (Opcode != Instruction::Add)
return Invalid;

if (InputTypeA != InputTypeB)
return Invalid;

EVT InputEVT = EVT::getEVT(InputTypeA);
EVT AccumEVT = EVT::getEVT(AccumType);

if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
return Invalid;
if (VF.isFixed() && (!ST->isNeonAvailable() || !ST->hasDotProd()))
return Invalid;

if (InputEVT == MVT::i8) {
switch (VF.getKnownMinValue()) {
default:
return Invalid;
case 8:
if (AccumEVT == MVT::i32)
Cost *= 2;
else if (AccumEVT != MVT::i64)
return Invalid;
break;
case 16:
if (AccumEVT == MVT::i64)
Cost *= 2;
else if (AccumEVT != MVT::i32)
return Invalid;
break;
}
} else if (InputEVT == MVT::i16) {
// FIXME: Allow i32 accumulator but increase cost, as we would extend
// it to i64.
if (VF.getKnownMinValue() != 8 || AccumEVT != MVT::i64)
return Invalid;
} else
return Invalid;

// AArch64 supports lowering mixed extensions to a usdot but only if the
// i8mm or sve/streaming features are available.
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
(OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
!ST->isSVEorStreamingSVEAvailable()))
return Invalid;

if (!BinOp || *BinOp != Instruction::Mul)
return Invalid;

return Cost;
}

bool enableOrderedReductions() const { return true; }

InstructionCost getInterleavedMemoryOpCost(
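To make the cost rules above concrete, a hedged query sketch (the helper name is illustrative, not from the patch). Per the logic above, assuming NEON plus the dotprod feature: i8 inputs at VF 16 into an i32 accumulator cost one basic unit (a single sdot/udot); the same inputs into i64, or VF 8 into i32, cost two; i8 at VF 8 into i64 and i16 at VF 8 into i64 cost one; every other shape is invalid, as are mixed sign/zero extends without i8mm or SVE.

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
using namespace llvm;

// Query the sdot shape: <16 x i8> x <16 x i8>, both sign-extended,
// multiplied, accumulated into i32 lanes. Expected result per the
// implementation above: a valid cost of TCC_Basic.
InstructionCost queryDotProductCost(const TargetTransformInfo &TTI,
                                    LLVMContext &Ctx) {
  Type *I8 = Type::getInt8Ty(Ctx);
  Type *I32 = Type::getInt32Ty(Ctx);
  return TTI.getPartialReductionCost(
      Instruction::Add, /*InputTypeA=*/I8, /*InputTypeB=*/I8,
      /*AccumType=*/I32, ElementCount::getFixed(16),
      TargetTransformInfo::PR_SignExtend, TargetTransformInfo::PR_SignExtend,
      Instruction::Mul);
}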
141 changes: 136 additions & 5 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7531,6 +7531,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
}
continue;
}
// The VPlan-based cost model is more accurate for partial reduction and
// comparing against the legacy cost isn't desirable.
if (isa<VPPartialReductionRecipe>(&R))
return true;
if (Instruction *UI = GetInstructionForCost(&R))
SeenInstrs.insert(UI);
}
@@ -8751,6 +8755,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
return Recipe;
}

/// Find all possible partial reductions in the loop and track all of those that
/// are valid so recipes can be formed later.
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// Find all possible partial reductions.
SmallVector<std::pair<PartialReductionChain, unsigned>, 1>
PartialReductionChains;
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars())
if (std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
getScaledReduction(Phi, RdxDesc, Range))
PartialReductionChains.push_back(*Pair);

// A partial reduction is invalid if any of its extends are used by
// something that isn't another partial reduction. This is because the
// extends are intended to be lowered along with the reduction itself.

// Build up a set of partial reduction bin ops for efficient use checking.
SmallSet<User *, 4> PartialReductionBinOps;
for (const auto &[PartialRdx, _] : PartialReductionChains)
PartialReductionBinOps.insert(PartialRdx.BinOp);

auto ExtendIsOnlyUsedByPartialReductions =
[&PartialReductionBinOps](Instruction *Extend) {
return all_of(Extend->users(), [&](const User *U) {
return PartialReductionBinOps.contains(U);
});
};

// Check if each use of a chain's two extends is a partial reduction
// and only add those that don't have non-partial reduction users.
for (auto Pair : PartialReductionChains) {
PartialReductionChain Chain = Pair.first;
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB))
ScaledReductionExitInstrs.insert(std::make_pair(Chain.Reduction, Pair));
}
}
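As an illustration of the use check above, two source-level sketches (assumed i8 inputs; neither is from the patch). The first compiles to a sext -> mul -> add chain whose extends are used only by the reduction's multiply, so the chain is recorded; in the second, the extend of A[I] has an additional non-reduction user, so ExtendIsOnlyUsedByPartialReductions fails and the chain is dropped.

#include <cstdint>

// Kept: both extends feed only the multiply of the add reduction.
int dot(const int8_t *A, const int8_t *B, int N) {
  int Sum = 0;
  for (int I = 0; I < N; ++I)
    Sum += (int32_t)A[I] * (int32_t)B[I]; // sext, sext -> mul -> add
  return Sum;
}

// Dropped: the sign extend of A[I] also feeds Out[I], a user that is not
// a partial-reduction binary op.
int dotAndStore(const int8_t *A, const int8_t *B, int32_t *Out, int N) {
  int Sum = 0;
  for (int I = 0; I < N; ++I) {
    int32_t AW = A[I]; // one sext, two users
    Out[I] = AW;
    Sum += AW * (int32_t)B[I];
  }
  return Sum;
}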

std::optional<std::pair<PartialReductionChain, unsigned>>
VPRecipeBuilder::getScaledReduction(PHINode *PHI,
const RecurrenceDescriptor &Rdx,
VFRange &Range) {
// TODO: Allow scaling reductions when predicating. The select at
// the end of the loop chooses between the phi value and most recent
// reduction result, both of which have different VFs to the active lane
// mask when scaling.
if (CM.blockNeedsPredicationForAnyReason(Rdx.getLoopExitInstr()->getParent()))
return std::nullopt;

auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr());
if (!Update)
return std::nullopt;

Value *Op = Update->getOperand(0);
Value *PhiOp = Update->getOperand(1);
if (Op == PHI) {
Op = Update->getOperand(1);
PhiOp = Update->getOperand(0);
}
if (PhiOp != PHI)
return std::nullopt;

auto *BinOp = dyn_cast<BinaryOperator>(Op);
if (!BinOp || !BinOp->hasOneUse())
return std::nullopt;

using namespace llvm::PatternMatch;
Value *A, *B;
if (!match(BinOp->getOperand(0), m_ZExtOrSExt(m_Value(A))) ||
!match(BinOp->getOperand(1), m_ZExtOrSExt(m_Value(B))))
return std::nullopt;

Instruction *ExtA = cast<Instruction>(BinOp->getOperand(0));
Instruction *ExtB = cast<Instruction>(BinOp->getOperand(1));

TTI::PartialReductionExtendKind OpAExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtA);
TTI::PartialReductionExtendKind OpBExtend =
TargetTransformInfo::getPartialReductionExtendKind(ExtB);

PartialReductionChain Chain(Rdx.getLoopExitInstr(), ExtA, ExtB, BinOp);

unsigned TargetScaleFactor =
PHI->getType()->getPrimitiveSizeInBits().getKnownScalarFactor(
A->getType()->getPrimitiveSizeInBits());

if (LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
VF, OpAExtend, OpBExtend,
std::make_optional(BinOp->getOpcode()));
return Cost.isValid();
},
Range))
return std::make_pair(Chain, TargetScaleFactor);

return std::nullopt;
}
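A worked reading of the scale factor above for the common i8-to-i32 dot product: the phi's 32-bit type divided by the 8-bit input type gives a factor of 4, so each accumulator lane absorbs four input lanes and a VF x i32 partial reduction consumes (4 * VF) x i8 inputs. A trivial standalone sketch:

#include <cassert>

int main() {
  unsigned PhiBits = 32;  // reduction phi type: i32
  unsigned InputBits = 8; // type feeding the extends: i8
  // Analogue of getKnownScalarFactor on the primitive sizes above.
  unsigned ScaleFactor = PhiBits / InputBits;
  assert(ScaleFactor == 4 && "four input lanes per accumulator lane");
  return 0;
}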

VPRecipeBase *
VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
ArrayRef<VPValue *> Operands,
@@ -8775,9 +8878,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
Legal->getReductionVars().find(Phi)->second;
assert(RdxDesc.getRecurrenceStartValue() ==
Phi->getIncomingValueForBlock(OrigLoop->getLoopPreheader()));
PhiRecipe = new VPReductionPHIRecipe(Phi, RdxDesc, *StartV,
CM.isInLoopReduction(Phi),
CM.useOrderedReductions(RdxDesc));

// If the PHI is used by a partial reduction, set the scale factor.
std::optional<std::pair<PartialReductionChain, unsigned>> Pair =
getScaledReductionForInstr(RdxDesc.getLoopExitInstr());
unsigned ScaleFactor = Pair ? Pair->second : 1;
PhiRecipe = new VPReductionPHIRecipe(
Phi, RdxDesc, *StartV, CM.isInLoopReduction(Phi),
CM.useOrderedReductions(RdxDesc), ScaleFactor);
} else {
// TODO: Currently fixed-order recurrences are modeled as chains of
// first-order recurrences. If there are no users of the intermediate
@@ -8809,6 +8917,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
return tryToWidenMemory(Instr, Operands, Range);

if (getScaledReductionForInstr(Instr))
return tryToCreatePartialReduction(Instr, Operands);

if (!shouldWiden(Instr, Range))
return nullptr;

@@ -8829,6 +8940,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
return tryToWiden(Instr, Operands, VPBB);
}

VPRecipeBase *
VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
ArrayRef<VPValue *> Operands) {
assert(Operands.size() == 2 &&
"Unexpected number of operands for partial reduction");

VPValue *BinOp = Operands[0];
VPValue *Phi = Operands[1];
if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe()))
std::swap(BinOp, Phi);

return new VPPartialReductionRecipe(Reduction->getOpcode(), BinOp, Phi,
Reduction);
}

void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
ElementCount MaxVF) {
assert(OrigLoop->isInnermost() && "Inner loop expected.");
@@ -9252,7 +9378,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
addCanonicalIVRecipes(*Plan, Legal->getWidestInductionType(), HasNUW, DL);

VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
Builder);

// ---------------------------------------------------------------------------
// Pre-construction: record ingredients whose recipes we'll need to further
@@ -9298,6 +9425,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
bool NeedsBlends = BB != HeaderBB && !BB->phis().empty();
return Legal->blockNeedsPredication(BB) || NeedsBlends;
});

RecipeBuilder.collectScaledReductions(Range);

auto *MiddleVPBB = Plan->getMiddleBlock();
VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi();
for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) {
@@ -9521,7 +9651,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {

// Collect mapping of IR header phis to header phi recipes, to be used in
// addScalarResumePhis.
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
Builder);
for (auto &R : Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
if (isa<VPCanonicalIVPHIRecipe>(&R))
continue;
(Diffs for the remaining 12 changed files are not shown.)
