diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 6a820fc2b37a..e98de4ebabfa 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -267,24 +267,19 @@ SmallVector<unsigned> getWarpOrder(Attribute layout) {
       return getWarpOrder(dotLayout.getParent());
     }
   }
-  auto order = getOrder(layout);
-  // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's
-  // M-major This is awkward. Since we can choose any warpOrder in Ampere, we
-  // should probably choose M-major and change `LinearLayoutConversion.cpp` and
-  // `MMAv2.cpp` to match.
-  if (auto mmaLayout = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-    if (mmaLayout.isHopper()) {
-      // Hopper MMA instructions force warps to be column-major
-      // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8
-      return getMatrixOrder(order.size(), /*rowMajor*/ false);
-    }
+
+  auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
+  if (nvidiaMma && nvidiaMma.isHopper()) {
+    auto rank = nvidiaMma.getWarpsPerCTA().size();
+    return getMatrixOrder(rank, /*rowMajor*/ false);
   } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
     // It's quite weird to talk about warp order when that the warps
     // are broadcasted along the K dimension
     llvm::report_fatal_error(
         "DotOperandEncoding::getWarpOrder not implemented");
   }
-  return order;
+
+  return getOrder(layout);
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
@@ -293,7 +288,11 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   }
   if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(layout)) {
     // Order doesn't really matter. We just have to be consistent when unpacking
-    // the elements in the MMAv2/V3 lowerings. We choose row-major
+    // the output elements in the LLVM lowerings. We choose row-major
+    auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
+    if (nvidiaMma && nvidiaMma.isHopper()) {
+      llvm::report_fatal_error("Testing");
+    }
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
     auto rank = distributedLayout.getWarpsPerCTA().size();
     return getMatrixOrder(rank, /*rowMajor*/ true);
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 6796307b7e22..2668f384978e 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -327,6 +327,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef<int64_t> shape,
   assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256);
   assert(k == 8 || k == 16 || k == 32);
 
+  // TODO Make the getOrder of Hopper explicit here via an assert
   MLIRContext *ctx = mma.getContext();
   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},