[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003

Merged · 18 commits · Nov 15, 2024
Changes from 1 commit
Lower shared > v3 dotOp & improve hoisting logic
ggengnv committed Nov 12, 2024
commit d2fff2683e7fde40524f09ce5b1a313b4795ed64
53 changes: 34 additions & 19 deletions lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -62,7 +62,7 @@ bool canHoistDotOpEncV3(Operation* op) {
// Currently, these instructions are not supported during lowering of
// shared -> dot_operand layout. Not all types and type conversions are
// supported.
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(op))
if (isa<arith::SelectOp>(op))
return false;

// Don't hoist through u1 -> fp casts as they aren't supported in
@@ -369,6 +369,7 @@ struct MMAV3HoistLayoutConversion

// Performs checks for early stop
NvidiaMmaEncodingAttr dstEnc;
Type inputEltTy;
{
auto srcEnc = dyn_cast<BlockedEncodingAttr>(getEncoding(alloc.getSrc()));
dstEnc =
@@ -388,6 +389,7 @@ struct MMAV3HoistLayoutConversion
auto srcTy = dyn_cast<RankedTensorType>(src->getResult(0).getType());
if (!srcTy)
return failure();
inputEltTy = srcTy.getElementType();

if (!canHoistDotOpEncV3(src))
return failure();
@@ -397,7 +399,7 @@ struct MMAV3HoistLayoutConversion
BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = [&](Operation *op) {
return (op->getParentRegion() == alloc->getParentRegion()) && !isa<LoadOp, LocalLoadOp>(op)
return (op->getParentRegion() == alloc->getParentRegion()) && !isa<LoadOp, LocalLoadOp, arith::ConstantOp>(op)
&& (op->getNumOperands() != 0); // Ensures all ops in slice have operands
};

@@ -417,18 +419,14 @@ struct MMAV3HoistLayoutConversion
for (auto operand : currOp->getOperands()) {
auto op = operand.getDefiningOp();
if (!slice.contains(op)) {
// TODO that this is overly restrictive. Can add support for ConstantOp and LocalLoad
if (!isa<LoadOp>(op))
if (!isa<LoadOp, arith::ConstantOp>(op))
return failure();

isFrontier = true;
}
}

if (isFrontier) {
if (!isa<LoadOp>(currOp->getOperand(0).getDefiningOp()))
return failure();

auto res = currOp->getResult(0);
if (!isBlockedRankedTensor(res))
return failure();
@@ -444,12 +442,16 @@ struct MMAV3HoistLayoutConversion
if (frontierOps.empty())
return failure();

// convert A operand
auto dotOperandEnc = DotOperandEncodingAttr::get(
dotOp.getContext(), /*opIdx=*/0, dstEnc, /*kWidth=*/0);
dotOp.getContext(), /*opIdx=*/0, dstEnc, inputEltTy);

// For each frontierOp:
// load; frontierOp; ...; warp_group_dot
// -> load; local_alloc; local_load; convert_layout; frontierOp; ...; warp_group_dot
// load; frontierOp; [hoistableOps...]; local_alloc; warp_group_dot
// -> load; local_alloc; local_load; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot
// or...
// constant; frontierOp; [hoistableOps...]; warp_group_dot
// -> constant; convert_layout; frontierOp; [hoistableOps...]; warp_group_dot
for (Operation *frontierOp : frontierOps) {
auto frontierTy = dyn_cast<RankedTensorType>(frontierOp->getResult(0).getType());

@@ -459,17 +461,30 @@ struct MMAV3HoistLayoutConversion
auto operandTy = cast<RankedTensorType>(operand.getType());
auto operandEltTy = operandTy.getElementType();

auto oldAllocTy = alloc.getType();
// TODO(ggengnv) previous encoding (oldAllocTy.getEncoding()) was for shared operand.
// Is it still appropriate for loading into registers?
auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy,
oldAllocTy.getEncoding(), oldAllocTy.getMemorySpace());
auto localAlloc = rewriter.create<LocalAllocOp>(alloc.getLoc(), newAllocTy, operand);
auto localLoad = rewriter.create<LocalLoadOp>(alloc.getLoc(), operandTy, localAlloc);
ConvertLayoutOp cvt;

Type cvtTy = RankedTensorType::get(
operandTy.getShape(), operandTy.getElementType(), dotOperandEnc);
auto cvt = rewriter.create<ConvertLayoutOp>(alloc.getLoc(), cvtTy, localLoad);

if (isa<LoadOp>(operand.getDefiningOp())) {
auto oldAllocTy = alloc.getType();
auto oldAllocEnc = cast<SharedEncodingAttr>(oldAllocTy.getEncoding());

auto newAllocEnc = SharedEncodingAttr::get(
oldAllocEnc.getContext(), dotOperandEnc, operandTy.getShape(),
getOrder(operandTy.getEncoding()),
getCTALayout(operandTy.getEncoding()),
operandTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);

auto newAllocTy = MemDescType::get(operandTy.getShape(), operandEltTy,
newAllocEnc, oldAllocTy.getMemorySpace());
auto localAlloc = rewriter.create<LocalAllocOp>(alloc.getLoc(), newAllocTy, operand);
auto localLoad = rewriter.create<LocalLoadOp>(alloc.getLoc(), operandTy, localAlloc);
cvt = rewriter.create<ConvertLayoutOp>(alloc.getLoc(), cvtTy, localLoad);
} else {
assert(isa<arith::ConstantOp>(operand.getDefiningOp()));
cvt = rewriter.create<ConvertLayoutOp>(alloc.getLoc(), cvtTy, operand);
}

newOperands.push_back(cvt);
}
@@ -511,13 +526,13 @@ class TritonGPUOptimizeDotOperandsPass
auto ret = pm.run(m);

mlir::RewritePatternSet patterns(context);
patterns.add<MMAV3HoistLayoutConversion>(context);
patterns.add<SwizzleShmemConvert>(context);
if (this->hoistLayoutConversion.getValue()) {
Collaborator:
nit: following MLIR style we usually don't have braces here

patterns.add<HoistLayoutConversion>(context);
}
patterns.add<FuseTransHopper>(context);
patterns.add<MMAV3UseRegOperand>(context);
patterns.add<MMAV3HoistLayoutConversion>(context);
ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
signalPassFailure();
@@ -36,14 +36,10 @@ class DecomposeLocalLoadToDotOperand
op.getType().getEncoding());
MemDescType srcType = op.getSrc().getType();
auto sharedEncoding = cast<SharedEncodingAttr>(srcType.getEncoding());
if (!dstDotOp)
if (!dstDotOp || !sharedEncoding.getHasLeadingOffset())
return failure();

auto parentEnc = cast<NvidiaMmaEncodingAttr>(dstDotOp.getParent()) ;
if (!parentEnc || parentEnc.getVersionMajor() == 3 || !sharedEncoding.getHasLeadingOffset())
return failure();

RankedTensorType type = op.getType();
auto parentEnc = dstDotOp.getParent();
int numWarps = triton::gpu::getNumWarpsPerCTA(parentEnc);
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
@@ -264,6 +264,30 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter,

// Return a vector of Values of the accumulator starting at startIndex, packing
// the values into 32 bits in case the accumulator is fp16.
//
// `elements` contains all loaded register values for operand A.
// This consists of operand A for possibly multiple wgmma instructions.
// For each wgmma, each warp in a warp group feeds a single "warp matrix".
// Each warp matrix consists of 2x2 "quads".
// Each thread holds several elements in each quad. Right before a wgmma,
// the bit widths of the elements a thread holds in each quad must add up to 32.
//
// These values are stored unrolled in `elements`.
// The ordering of dimensions is as follows:
// batch (only 1 batch for Hopper currently)
// matM (m-index of the "warp matrix")
// matK (k-index of the "warp matrix")
// quadM (m-index of the "quad" in the core matrix)
// quadK (k-index of the "quad" in the core matrix)
// vecIdx (index of the element in the quad; this is always along the k-dim)
//
// This ordering is decided when a tensor in DotOpEnc is lowered into llvm.
// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand.
// Thus, both lowerings must obey the ordering above for the code below to be correct.
Collaborator:
The decision should be based on the layout definition rather than a convention between different lowerings. This comment is a bit misleading, and maybe we should describe the layout more explicitly instead.

Contributor Author:
There's currently nothing in the dot operand layout attributes that would indicate the ordering of matM and matK though, so I assumed it was just implicit logic. I could move this comment to the definition of DotOpEncoding or perhaps remove it altogether to avoid confusion?

Collaborator:
Yes, the layout is not well documented and/or defined, but this is how it should work :) I think moving it to DotOpEncoding is good; this is still valuable in my opinion.

//
// Additionally, note that WGMMA expects quadK ordered before quadM (i.e.
// iterate along m-dim first); see loadI and mmaI.
llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
Location loc,
const SmallVector<Value> &elements,
@@ -281,20 +305,24 @@ llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
}
Type elementType = elements[0].getType();
int numElemsPer32Bits = 32 / elementType.getIntOrFloatBitWidth();
assert(numElements == 4 * numElemsPer32Bits);

// For FP16 and BF16 we need to pack accumulator into 32-bit integers.
int num32BitValues = numElements / numElemsPer32Bits;
llvm::SmallVector<Value> mmaOut(num32BitValues);
llvm::SmallVector<Value> mmaOut(4);
Type packTy = vec_ty(elementType, numElemsPer32Bits);
for (int i = 0; i < num32BitValues; ++i) {
Value pack = rewriter.create<LLVM::UndefOp>(loc, packTy);
for (int j = 0; j < numElemsPer32Bits; ++j) {
Value element = elements[startIndex + i * numElemsPer32Bits + j];
pack = insert_element(packTy, pack, element, i32_val(j));
for (int quadK = 0; quadK < 2; quadK++)
for (int quadM = 0; quadM < 2; quadM++) {
int loadI = quadM * 2 + quadK;
int mmaI = quadK * 2 + quadM;
Value pack = rewriter.create<LLVM::UndefOp>(loc, packTy);
for (int j = 0; j < numElemsPer32Bits; ++j) {
Value element = elements[startIndex + loadI * numElemsPer32Bits + j];
pack = insert_element(packTy, pack, element, i32_val(j));
}
pack = bitcast(pack, rewriter.getIntegerType(32));
mmaOut[mmaI] = pack;
}
pack = bitcast(pack, rewriter.getIntegerType(32));
mmaOut[i] = pack;
}

return mmaOut;
}
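
(Editorial sketch, not part of the PR.) To make the quad reordering and packing above concrete: for fp16, two 16-bit elements go into each 32-bit register; `elements` stores quads with quadM varying slowest (loadI = quadM*2 + quadK), while wgmma consumes registers with quadK varying slowest (mmaI = quadK*2 + quadM). A standalone C++ sketch that mirrors this with plain integers, emulating the insert_element + bitcast packing with shifts; the values and the fp16 assumption are illustrative only:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // 2x2 quads with 2 fp16-sized elements each, stored quadM-major as in the
  // comment block above; the element values are placeholders.
  std::vector<std::uint16_t> elements = {0, 1, 2, 3, 4, 5, 6, 7};
  const int numElemsPer32Bits = 2; // 32 / 16 for fp16
  std::uint32_t mmaOut[4];

  for (int quadK = 0; quadK < 2; ++quadK)
    for (int quadM = 0; quadM < 2; ++quadM) {
      int loadI = quadM * 2 + quadK; // position in `elements` (quadM slowest)
      int mmaI = quadK * 2 + quadM;  // position wgmma expects (quadK slowest)
      std::uint32_t pack = 0;
      for (int j = 0; j < numElemsPer32Bits; ++j)
        pack |= static_cast<std::uint32_t>(elements[loadI * numElemsPer32Bits + j])
                << (16 * j);
      mmaOut[mmaI] = pack;
      std::printf("quadM=%d quadK=%d -> loadI=%d, mmaI=%d, pack=0x%08x\n",
                  quadM, quadK, loadI, mmaI, (unsigned)pack);
    }
  return 0;
}
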