[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering for SMEM-to-MMAv3 DotOp Copy #5003
@@ -264,6 +264,30 @@ DotOpMmaV3SmemLoader loadB(const LLVMTypeConverter *typeConverter,
 // Return a vector of Value of the accumulator start at startIndex and pack the
 // values into 32bits in case the accumulator is fp16.
+//
+// `elements` contains all loaded register values for operand A.
+// This consists of operand A for possibly multiple wgmma instructions.
+// For each wgmma, each warp in a warp group feeds a single "warp matrix".
+// Each warp matrix consists of 2x2 "quads".
+// Each thread holds several elements in each quad. Right before a wgmma, the
+// bitwidths of the elements in each quad should add up to 32.
+//
+// These values are stored unrolled in `elements`.
+// The ordering of dimensions is as follows:
+//   batch (only 1 batch for Hopper currently)
+//   matM (m-index of the "warp matrix")
+//   matK (k-index of the "warp matrix")
+//   quadM (m-index of the "quad" in the core matrix)
+//   quadK (k-index of the "quad" in the core matrix)
+//   vecIdx (index of the element in the quad; this is always along the k-dim)
+//
+// This ordering is decided when a tensor in DotOpEnc is lowered into llvm.
+// For WGMMA this happens in both SharedToDotOperand and MMAToDotOperand.
+// Thus, both lowerings must obey the above ordering for the code below to be correct.

[Inline review thread on the comment above]

Reviewer: The decision should be based on the layout definition rather than a convention between different lowerings. This comment is a bit misleading; maybe we should describe the layout more explicitly instead.

Author: There's currently nothing in the dot operand layout attributes that would indicate the ordering of matM and matK though, so I assumed it was just implicit logic. I could move this comment to the definition of DotOpEncoding, or perhaps remove it altogether to avoid confusion?

Reviewer: Yes, the layout is not well documented and/or defined, but this is how it should work :) I think moving it to DotOpEncoding is good; this is still valuable in my opinion.

+//
+// Additionally, note that WGMMA expects quadK ordered before quadM (i.e.,
+// iterate along the m-dim first); see loadI and mmaI.
 llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
                                  Location loc,
                                  const SmallVector<Value> &elements,
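To make the dimension ordering described in the new comment concrete, here is a small standalone sketch. It is not part of the PR: the helper name, its parameters, and the omission of the batch dimension are illustrative assumptions; only the ordering itself (batch outermost, vecIdx innermost) comes from the comment above.

```cpp
// Hypothetical illustration of the unrolled ordering of `elements`:
// batch (omitted: currently a single batch on Hopper), matM, matK,
// quadM, quadK, vecIdx -- with vecIdx varying fastest.
int flatElementIndex(int matM, int matK, int quadM, int quadK, int vecIdx,
                     int numMatK, int elemsPerQuad) {
  int idx = matM;                     // warp-matrix index along m
  idx = idx * numMatK + matK;         // warp-matrix index along k
  idx = idx * 2 + quadM;              // 2x2 quads per warp matrix: m first...
  idx = idx * 2 + quadK;              // ...then k
  idx = idx * elemsPerQuad + vecIdx;  // elements within a quad (along k)
  return idx;
}
```

For fp16 operands, `elemsPerQuad` would be 2 (32 / 16), matching `numElemsPer32Bits` in `loadReg`; the four (quadM, quadK) combinations then account for the `4 * numElemsPer32Bits` elements that the new assert below expects per call.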
@@ -281,20 +305,24 @@ llvm::SmallVector<Value> loadReg(ConversionPatternRewriter &rewriter,
   }
   Type elementType = elements[0].getType();
   int numElemsPer32Bits = 32 / elementType.getIntOrFloatBitWidth();
+  assert(numElements == 4 * numElemsPer32Bits);

   // For FP16 and BF16 we need to pack accumulator into 32-bit integers.
-  int num32BitValues = numElements / numElemsPer32Bits;
-  llvm::SmallVector<Value> mmaOut(num32BitValues);
+  llvm::SmallVector<Value> mmaOut(4);
   Type packTy = vec_ty(elementType, numElemsPer32Bits);
-  for (int i = 0; i < num32BitValues; ++i) {
-    Value pack = rewriter.create<LLVM::UndefOp>(loc, packTy);
-    for (int j = 0; j < numElemsPer32Bits; ++j) {
-      Value element = elements[startIndex + i * numElemsPer32Bits + j];
-      pack = insert_element(packTy, pack, element, i32_val(j));
+  for (int quadK = 0; quadK < 2; quadK++)
+    for (int quadM = 0; quadM < 2; quadM++) {
+      int loadI = quadM * 2 + quadK;
+      int mmaI = quadK * 2 + quadM;
+      Value pack = rewriter.create<LLVM::UndefOp>(loc, packTy);
+      for (int j = 0; j < numElemsPer32Bits; ++j) {
+        Value element = elements[startIndex + loadI * numElemsPer32Bits + j];
+        pack = insert_element(packTy, pack, element, i32_val(j));
+      }
+      pack = bitcast(pack, rewriter.getIntegerType(32));
+      mmaOut[mmaI] = pack;
     }
-    pack = bitcast(pack, rewriter.getIntegerType(32));
-    mmaOut[i] = pack;
-  }

   return mmaOut;
 }
Reviewer nit: following MLIR style, we usually don't have braces here.
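To see the effect of the new loadI/mmaI bookkeeping in isolation, below is a minimal sketch in plain C++ (no MLIR); the string labels are made-up stand-ins for the four packed 32-bit registers. It reorders values from load order (quadM-major, as described in the comment added above) into the order WGMMA consumes them (quadK-major, i.e. the m-index varies fastest).

```cpp
// Minimal sketch of the loadI -> mmaI remapping done by the new packing loop.
// `loaded` stands in for the four packed 32-bit registers in load order
// (quadM-major); `mmaOut` is the order wgmma expects (quadK-major).
#include <array>
#include <cstdio>

int main() {
  // Stand-in values labelled by (quadM, quadK), stored quadM-major.
  std::array<const char *, 4> loaded = {"m0k0", "m0k1", "m1k0", "m1k1"};

  std::array<const char *, 4> mmaOut{};
  for (int quadK = 0; quadK < 2; quadK++)
    for (int quadM = 0; quadM < 2; quadM++) {
      int loadI = quadM * 2 + quadK; // position within the loaded elements
      int mmaI = quadK * 2 + quadM;  // position expected by wgmma
      mmaOut[mmaI] = loaded[loadI];
    }

  // Prints: m0k0 m1k0 m0k1 m1k1 (m varies fastest, as wgmma expects).
  for (const char *s : mmaOut)
    std::printf("%s ", s);
  std::printf("\n");
  return 0;
}
```

Sizing `mmaOut` to 4 and the new `assert(numElements == 4 * numElemsPer32Bits)` go together: each call to `loadReg` now appears to pack exactly one warp matrix of 2x2 quads, so for fp16 (`numElemsPer32Bits == 2`) that is 8 input elements producing 4 packed registers.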