From 56584c468c27dc9492df9c3294897212b4e7255c Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Sat, 2 Nov 2024 21:34:54 +0000
Subject: [PATCH] [BACKEND] Fix DotOperand(Ampere) LinearLayoutConversion
 (#5038)

We also clean up `TritonGPU/IR/Dialect.cpp` a bit, using some auxiliary
functions to make the intent clearer.

We add a few asserts in the `LinearLayoutConversion` to make it clear
why we do certain things here and there.

We also kill `getCvtOrder`, as it was not used anywhere.
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |  11 ++
 lib/Dialect/TritonGPU/IR/Dialect.cpp          |  74 ++++++------
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 104 +++++++++++++-----
 .../DotOpToLLVM/MMAv2.cpp                     |  16 +--
 .../TritonGPU/LinearLayoutConversionsTest.cpp |  41 ++++++-
 5 files changed, 166 insertions(+), 80 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index 74ea99b58891..cfc00926ddc2 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -130,6 +130,17 @@ unsigned getNumWarpsPerCTA(Attribute layout);
 
 unsigned getNumCTAs(Attribute layout);
 
+// Return the order that represents whether a batch of matrices of shape
+// [*, m, n] with len(shape) == rank is stored in row-major or
+// column-major order.
+SmallVector getMatrixOrder(unsigned rank, bool rowMajor);
+
+// Return the order that represents whether the dot operand is kMajor
+// (contiguous in the inner, K dimension) or contiguous in the outer
+// dimension.
+SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank,
+                                  bool kMajor);
+
 bool isExpensiveCat(CatOp cat, Attribute targetEncoding);
 
 // Return true if a view between the two types cannot be implemented as a no-op.
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 3b5316ecc0e3..6a820fc2b37a 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -235,6 +235,19 @@ static SmallVector eraseOrder(ArrayRef order,
   return resOrder;
 }
 
+SmallVector getMatrixOrder(unsigned rank, bool rowMajor) {
+  // Return the order that represents whether a batch of matrices of shape
+  // [*, m, n] with len(shape) == rank is stored in row-major or
+  // column-major order.
+  assert(rank >= 2);
+  SmallVector order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  if (!rowMajor) {
+    std::swap(order[0], order[1]);
+  }
+  return order;
+}
+
 SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                   bool kMajor) {
   // kMajor: if true, the matrix is fastest-running on k,
   //         otherwise it is on m (resp. n)
   // batch (if rank == 3) is always the slowest running dimension
   assert(rank == 2 || rank == 3);
   assert(opIdx == 0 || opIdx == 1);
-  SmallVector order(rank);
-  std::iota(order.rbegin(), order.rend(), 0);
-  // If opIdx is 1 and kMajor is true, the order is [0, 1]
-  // (resp. [1, 2, 0] if rank == 3)
-  // Same if opIdx is 0 and kMajor is false
-  if (bool(opIdx) == kMajor) {
-    std::swap(order[0], order[1]);
-  }
-  return order;
+  auto rowMajor = bool(opIdx) != kMajor;
+  return getMatrixOrder(rank, rowMajor);
 }
 
 SmallVector getWarpOrder(Attribute layout) {
@@ -262,20 +268,21 @@ SmallVector getWarpOrder(Attribute layout) {
     }
   }
   auto order = getOrder(layout);
-  // FIXME: This mmaLayout if should just return
-  // getOrderForDotOperand(0, order.size(), kMajor=false)
-  // as mma has the same order as DotOperand(opIdx=0)
+  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
+  // M-major. This is awkward. Since we can choose any warpOrder in Ampere,
+  // we should probably choose M-major and change
+  // `LinearLayoutConversions.cpp` and `MMAv2.cpp` to match.
   if (auto mmaLayout = dyn_cast(layout)) {
     if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force a warp order of [0, 1]. See docs:
+      // Hopper MMA instructions force warps to be column-major
       // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      auto it = std::find(order.begin(), order.end(), 0);
-      order.erase(it);
-      order.insert(order.begin(), 0);
+      return getMatrixOrder(order.size(), /*rowMajor*/ false);
     }
   } else if (auto dotOpLayout = dyn_cast(layout)) {
-    order = getOrderForDotOperand(dotOpLayout.getOpIdx(), order.size(),
-                                  /*kMajor*/ false);
+    // It's quite weird to talk about warp order when the warps
+    // are broadcasted along the K dimension
+    llvm::report_fatal_error(
+        "DotOperandEncoding::getWarpOrder not implemented");
   }
   return order;
 }
@@ -285,11 +292,11 @@ SmallVector getOrder(Attribute layout) {
     return llvm::to_vector(blockedLayout.getOrder());
   }
   if (auto mmaLayout = dyn_cast(layout)) {
+    // Order doesn't really matter. We just have to be consistent when
+    // unpacking the elements in the MMAv2/V3 lowerings. We choose row-major.
    auto distributedLayout = cast(layout);
    auto rank = distributedLayout.getWarpsPerCTA().size();
-    SmallVector order(rank);
-    std::iota(order.rbegin(), order.rend(), 0);
-    return order;
+    return getMatrixOrder(rank, /*rowMajor*/ true);
   }
   if (auto dotLayout = dyn_cast(layout)) {
     auto rank = dotLayout.getWarpsPerCTA().size();
@@ -421,7 +428,7 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
   else if (auto wmmaLayout = dyn_cast(layout))
     warpsPerCTA = wmmaLayout.getWarpsPerCTA();
   else if (auto dotLayout = dyn_cast(layout))
-    return getNumWarpsPerCTA(dotLayout.getParent());
+    warpsPerCTA = dotLayout.getWarpsPerCTA();
   else if (auto sharedLayout = dyn_cast(layout))
     llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
   else
@@ -2136,25 +2143,12 @@ unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
 SmallVector NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
     ArrayRef shape, int kWidth, int opIdx) const {
   assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
-  auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto rank = parentShapePerCTATile.size();
+  auto shapePerCTATile = getShapePerCTATile(shape);
+  auto rank = shapePerCTATile.size();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
   // 4 threads * 2 subtiles
-  unsigned kWidthTile = kWidth * 2 * 4;
-  if (opIdx == 0) {
-    if (rank == 2)
-      return {parentShapePerCTATile[rank - 2], kWidthTile};
-    else
-      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
-              kWidthTile};
-  } else if (opIdx == 1) {
-    if (rank == 2)
-      return {kWidthTile, parentShapePerCTATile[rank - 1]};
-    else
-      return {parentShapePerCTATile[0], kWidthTile,
-              parentShapePerCTATile[rank - 1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-  }
+  shapePerCTATile[kDim] = kWidth * 2 * 4;
+  return shapePerCTATile;
 }
 
 SmallVector NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 6978ccfb2553..6796307b7e22 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -41,6 +41,17 @@ SmallVector standardOutDimNames(MLIRContext *ctx, int rank) {
   return ret;
 }
 
+// TODO: Have `order` be a mandatory argument of standardOutDimNames.
+SmallVector permuteDimNames(const SmallVector &names,
+                            const SmallVector &order) {
+  assert(names.size() == order.size());
+  SmallVector ret;
+  for (unsigned i : order) {
+    ret.push_back(names[i]);
+  }
+  return ret;
+}
+
 void assertIsRegisterLayout(const LinearLayout &layout) {
   assert(layout.getNumInDims() > 0);
   MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
@@ -281,15 +292,19 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape,
   MLIRContext *ctx = mma.getContext();
   SmallVector dimNames = standardOutDimNames(ctx, rank);
+  auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
+  // By using `orderedDimNames` below, we set the order to be row-major
+  assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
 
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
        {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  ctaLayout *= identityND(
-      S("warp"), mma.getWarpsPerCTA(),
-      llvm::to_vector(llvm::reverse(llvm::seq(rank))), dimNames);
+      ArrayRef(orderedDimNames).take_front(2));
+  assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  // redundant with the order of the out dims.
+  ctaLayout *=
+      identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder(), dimNames);
 
   return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
 }
@@ -322,10 +337,14 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape,
   ctaLayout *= LinearLayout::identity1D(n / ctaLayout.getOutDimSize(S("dim1")),
                                         S("register"), S("dim1"));
 
-  // Expand the `warp` dimension according to warpsPerCTA.
-  //
-  // It's weird that this is order [0,1] when MMAv2's warpsPerCTA is [1,0], but
-  // this really does seem to be correct.
+  // The order given by choosing (`dim1`, `dim0`) is [1, 0], that is, N-major.
+  // Since the warpOrder needs to be M-major, we need to transpose the out
+  // dimensions AND transpose the order.
+  // FIXME(Lezcano). identityND should not have an `order` param as it's
+  //                 redundant. The order is already given by the order of the
+  //                 out dims, and if it has an order, it shouldn't change the
+  //                 order of the out dims.
+  assert(getWarpOrder(mma) == SmallVector({0, 1}));
   ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1},
                           {S("dim0"), S("dim1")})
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
 
   return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
 }
@@ -843,18 +862,24 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const {
 
 LinearLayout ampereDotToLinearLayout(ArrayRef shape,
                                      DotOperandEncodingAttr dot) {
-  // TODO,BE. Implement ampereMMA in terms of this one
+  // Note that, even though MMAv2 looks similar to this layout, they are only
+  // the same at the register and lane level. The warp treatment is different!
   int rank = shape.size();
   auto mma = cast(dot.getParent());
   int kWidth = dot.getKWidth();
   bool isA = dot.getOpIdx() == 0;
-  assert(mma.isAmpere());
   assert((rank == 2 && mma.getInstrShape() == ArrayRef({16, 8})) ||
          (rank == 3 && mma.getInstrShape() == ArrayRef({1, 16, 8})));
+  assert(mma.isAmpere());
 
   MLIRContext *ctx = mma.getContext();
-  SmallVector dimNames = standardOutDimNames(ctx, rank);
+  // A and B have kMajor order
+  assert(getOrder(dot) ==
+         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+
+  auto kMajorDims =
+      permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
 
   // Implement A. For B transpose in the end
   std::vector> registers;
@@ -881,24 +906,51 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape,
   }
   registers.push_back({i, 0});
 
-  if (!isA) {
-    for (auto &r : registers) {
-      std::swap(r[0], r[1]);
+  LinearLayout ctaLayout({{S("register"), registers}, {S("lane"), lanes}},
+                         ArrayRef(kMajorDims).take_front(2));
+
+  // Let warpsPerCTAMma = {2, 2}, then
+  // warpsPerCTA = {2, 1} for opA and warpsPerCTA = {1, 2} for opB
+  // assume warpOrder = {1, 0}
+  // Assume that C is tiled by 2x2 tiles. Since warpOrder = {1, 0}, the warps
+  // own C as per the following layout:
+  // C: 0 | 1
+  //    - | -
+  //    2 | 3
+  // In order to be able to compute C, we need the following warp tiling of
+  // A and B:
+  // A: 0 1 | 0 1   B: 0 2 | 1 3
+  //    - - | - -      - - | - -
+  //    2 3 | 2 3      0 2 | 1 3
+  // In particular, for A and B we need to broadcast along K.
+
+  assert(mma.getWarpOrder() == getMatrixOrder(rank, /*rowMajor=*/true));
+  auto warpsPerCTAMma = mma.getWarpsPerCTA();
+  std::vector> warps;
+  if (isA) {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, i});
+    }
+  } else {
+    for (int i = 1; i < warpsPerCTAMma[1]; i *= 2) {
+      warps.push_back({0, i});
     }
-    for (auto &l : lanes) {
-      std::swap(l[0], l[1]);
+    for (int i = 1; i < warpsPerCTAMma[0]; i *= 2) {
+      warps.push_back({0, 0});
+    }
+  }
+  if (rank == 3) {
+    for (auto &w : warps) {
+      w.push_back(0);
     }
   }
 
-  LinearLayout ctaLayout(
-      {{S("register"), registers}, {S("lane"), lanes}},
-      llvm::to_vector(llvm::reverse(ArrayRef(dimNames).take_back(2))));
-
-  auto order = dot.getCTAOrder();
-  assert(order[0] == rank - 1 && order[1] == rank - 2);
-  ctaLayout *= identityND(S("warp"), dot.getWarpsPerCTA(), order, dimNames);
+  ctaLayout *= LinearLayout({{S("warp"), warps}}, kMajorDims);
 
-  return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape);
+  return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
 }
 
 std::optional
@@ -907,7 +959,7 @@ DotOperandEncodingAttr::toLinearLayout(ArrayRef shape) const {
   if (auto mfmaLayout = llvm::dyn_cast(parent)) {
     return mfmaDotToLinearLayout(*this, shape);
   } else if (auto mma = mlir::dyn_cast(parent)) {
-    if (mma.getVersionMajor() == 2 && mma.getVersionMinor() == 0) {
+    if (mma.isAmpere()) {
       return ampereDotToLinearLayout(shape, *this);
     }
   }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
index b03fb0989dda..508f03227cf6 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp
@@ -121,19 +121,15 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
   }
 
   if (dot.getOpIdx() == 1) {
-    // there are kWidth * 2 elems packed as bf16x2
     int elemsInTile = dot.getKWidth();
-    // n0 and n1 are unrolled in the legacy path
-    // Unrolling n1 makes some sense, but unrolling n0 makes absolutely no
-    // sense IMO
+    // n0 is unrolled in the legacy path, which makes no sense
     n0 *= 2;
-    n1 *= 2;
     for (auto b = 0; b < batch; ++b)
-      for (auto j = 0; j < n1 / elemsInTile; ++j)
-        for (auto i = 0; i < n0; ++i)
-          for (auto k = 0; k < elemsInTile; ++k) {
-            vals[{b, i, elemsInTile * j + k}] = elems[offset++];
-          }
+      for (auto i = 0; i < n0; ++i)
+        for (auto j = 0; j < n1; ++j) {
+          vals[{b, i, 2 * j}] = elems[offset++];
+          vals[{b, i, 2 * j + 1}] = elems[offset++];
+        }
     return vals;
   }
 }
diff --git a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
index d4c15bbad03f..d662537ed72d 100644
--- a/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
+++ b/unittest/Dialect/TritonGPU/LinearLayoutConversionsTest.cpp
@@ -555,14 +555,14 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
              {2, 0},
              {4, 0},
             {32, 0},
+            {64, 0},
             {0, 8},
             {0, 16},
-            {0, 32},
-            {64, 0}}},
+            {0, 32}}},
           {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
           {
              S("warp"),
-             {},
+             {{0, 0}, {0, 0}},
          },
          {S("block"), {}},
       },
@@ -582,13 +582,46 @@ TEST_F(LinearLayoutConversionsTest, DotMMAv2_large_warp4_kwidth8) {
          {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
          {
             S("warp"),
-            {},
+            {{0, 0}, {0, 0}},
         },
        {S("block"), {}},
      },
      {S("dim0"), S("dim1")}));
 }
 
+TEST_F(LinearLayoutConversionsTest, DotMMAv2_split_warp_kwidth8) {
+  EXPECT_EQ(
+      toLinearLayout({32, 64}, dotMMAv2(0, 8, {2, 2})),
+      LinearLayout({{S("register"), {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}}},
+                    {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                    {S("warp"), {{0, 0}, {16, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({64, 16}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout({{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}}},
+                    {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+                    {S("warp"), {{0, 8}, {0, 0}}},
+                    {S("block"), {}}},
+                   {S("dim0"), S("dim1")}));
+  EXPECT_EQ(toLinearLayout({64, 128}, dotMMAv2(0, 8, {2, 2})),
+            LinearLayout(
+                {{S("register"),
+                  {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {32, 0}}},
+                 {S("lane"), {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
+                 {S("warp"), {{0, 0}, {16, 0}}},
+                 {S("block"), {}}},
+                {S("dim0"), S("dim1")}));
+  EXPECT_EQ(
+      toLinearLayout({128, 32}, dotMMAv2(1, 8, {2, 2})),
+      LinearLayout(
+          {{S("register"), {{1, 0}, {2, 0}, {4, 0}, {32, 0}, {64, 0}, {0, 16}}},
+           {S("lane"), {{8, 0}, {16, 0}, {0, 1}, {0, 2}, {0, 4}}},
+           {S("warp"), {{0, 8}, {0, 0}}},
+           {S("block"), {}}},
+          {S("dim0"), S("dim1")}));
+}
+
 TEST_F(LinearLayoutConversionsTest, MFMA32_2x4Warps) {
   auto mfmaNT = mfma(/*warps=*/{2, 4}, /*mDim=*/32, /*nDim=*/32,
                      /*isTransposed=*/false);
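For reference, the order convention that `getMatrixOrder` and `getOrderForDotOperand` encode can be illustrated with a small standalone C++ sketch. This is an illustration only, independent of the Triton sources: the two functions below are hypothetical reimplementations using `std::vector` rather than the `SmallVector`-based library code in the patch.

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <numeric>
#include <vector>

// Row-major order for a batch of [*, m, n] matrices lists the dims from
// fastest- to slowest-running: [rank-1, rank-2, ..., 0]. Column-major swaps
// the two fastest-running entries.
std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0u);
  if (!rowMajor)
    std::swap(order[0], order[1]);
  return order;
}

// A (opIdx == 0, shape [m, k]) is kMajor when it is row-major;
// B (opIdx == 1, shape [k, n]) is kMajor when it is column-major.
std::vector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                            bool kMajor) {
  assert(opIdx == 0 || opIdx == 1);
  bool rowMajor = bool(opIdx) != kMajor;
  return getMatrixOrder(rank, rowMajor);
}

int main() {
  // The kMajor order of a rank-3 B operand is [1, 2, 0], matching the comment
  // removed from getOrderForDotOperand in the diff above.
  for (unsigned d : getOrderForDotOperand(/*opIdx=*/1, /*rank=*/3, /*kMajor=*/true))
    std::printf("%u ", d); // prints: 1 2 0
  std::printf("\n");
  return 0;
}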