From 832c369a78ac263e9717e5e83f9c52ce13e6e782 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Tue, 5 Nov 2024 08:29:48 +0000
Subject: [PATCH] [BACKEND] Fix uses of getOrder(DotOperand(Nvidia) and MMA(Nvidia)) (#5055)

We use `getOrder` very liberally throughout the codebase, when we really
meant to use `getThreadOrder`. This is an issue when the input layout is a
`DotOperand(mma(opIdx=1))`, where the thread order and the matrix order are
opposite.

Found this to be an issue when a PR changed the `getOrder` of
`DotOperand(Hopper)` to an incorrect one and CI still passed! The issue here
is that the LLVM lowering for wgmma and the LinearLayout do not use
`getOrder`, but many other subsystems do, so many heuristics would get an
incorrect order and potentially be disabled. This is particularly problematic
for `DotOperand(opIdx=1)` on NVIDIA hardware, as `getThreadOrder` and
`getOrder` are different!

While doing so, we:
- Audit most (all?) of the calls to `getOrder(dotOperand)`. It turns out that
  most of them really meant `getThreadOrder`
- Fix the ordering methods of `SliceEncodingAttr` to be consistent
- Move the implementation of `getWarpOrder` to the Attr classes, because of OOP

The test strategy was to add `llvm::report_fatal_error("Testing");` within
`getOrder(nvidiaMma)` and `getOrder(DotOperand(nvidiaMma))` and triage all
the errors that were raised in CI.

(cherry picked from commit 38a11b859fff79ea214256d3f1cfe43d54e36c2c)
---
 include/triton/Analysis/Utility.h | 2 +-
 lib/Analysis/Allocation.cpp | 3 +
 lib/Analysis/AxisInfo.cpp | 6 +-
 lib/Analysis/Utility.cpp | 10 +--
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 4 +-
 .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 7 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp | 71 +++++++++----------
 .../TritonGPU/IR/LinearLayoutConversions.cpp | 1 +
 .../Transforms/ReduceDataDuplication.cpp | 10 +--
 .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 7 +-
 .../DecomposeUnsupportedConversions.cpp | 2 +-
 .../LoadStoreOpToLLVM.cpp | 4 +-
 12 files changed, 63 insertions(+), 64 deletions(-)

diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index 37d24ac929a9..34c612302ebc 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector getScratchRepShape();
-  SmallVector getOrderWithAxisAtBeginning();
+  SmallVector getThreadOrderWithAxisAtBeginning();
   unsigned getScratchSizeInBytes();
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 276a6e7004df..b8ea1116e632 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -80,6 +80,9 @@ static SmallVector getRepShapeForCvt(RankedTensorType srcTy,
   auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
   auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
+  assert(srcTy.getRank() == dstTy.getRank() &&
+         "src and dst must have the same rank");
+
   unsigned rank = dstTy.getRank();
   SmallVector repShape(rank);
   for (unsigned d = 0; d < rank; ++d) {
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 717df8d1bd5a..230990ee742e 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
- auto order = triton::gpu::getOrder(layout); + auto order = triton::gpu::getThreadOrder(layout); unsigned align = getPtrAlignment(ptr); auto uniqueContigPerThread = @@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { if (!axisInfo) return 1; auto layout = tensorTy.getEncoding(); - auto order = triton::gpu::getOrder(layout); + auto order = triton::gpu::getThreadOrder(layout); auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); auto maxContig = axisInfo->getContiguity(order[0]); auto elemNumBits = triton::getPointeeBitWidth(tensorTy); @@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { auto *axisInfo = getAxisInfo(mask); if (!axisInfo) return 1; - auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); + auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding()); auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " << alignment); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 30ba11c31782..13d5819612bc 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) { return axis; } -SmallVector getParentOrder(Attribute layout) { +SmallVector getParentThreadOrder(Attribute layout) { if (auto sliceEncoding = mlir::dyn_cast(layout)) { - return getParentOrder(sliceEncoding.getParent()); + return getParentThreadOrder(sliceEncoding.getParent()); } return getThreadOrder(layout); } @@ -44,12 +44,12 @@ SmallVector getParentOrder(Attribute layout) { // TODO(jlebar): Move this class into namespace triton. bool ReduceOpHelper::isReductionOnLayoutFastAxis() { return getParentAxis(getSrcLayout(), axis) == - getParentOrder(getSrcLayout())[0]; + getParentThreadOrder(getSrcLayout())[0]; } -SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { +SmallVector ReduceOpHelper::getThreadOrderWithAxisAtBeginning() { auto srcLayout = getSrcLayout(); - auto order = getOrder(srcLayout); + auto order = getThreadOrder(srcLayout); auto it = std::find(order.begin(), order.end(), axis); // delete the axis from order order.erase(it); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 966f6d31c725..4e8053923399 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -276,7 +276,7 @@ struct ReduceOpConversion getMultiDimWarpId(helper, warpId, loc, rewriter); Value warpIdAxis = multiDimWarpId[axis]; - auto smemOrder = helper.getOrderWithAxisAtBeginning(); + auto smemOrder = helper.getThreadOrderWithAxisAtBeginning(); for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = it.second; @@ -363,7 +363,7 @@ struct ReduceOpConversion Location loc = op.getLoc(); auto srcLayout = helper.getSrcLayout(); auto axis = op.getAxis(); - auto smemOrder = helper.getOrderWithAxisAtBeginning(); + auto smemOrder = helper.getThreadOrderWithAxisAtBeginning(); SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 297a94e851f6..be5feedd67bf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern { 
int numContiguousValues = 1; auto encoding = cast( cast(op.getSrc().getType()).getEncoding()); - int splitDim = encoding.getOrder().size() - 1; - for (int i = 0; i < encoding.getOrder().size(); i++) { - if (encoding.getOrder()[i] == splitDim) + int splitDim = encoding.getThreadOrder().size() - 1; + for (int i = 0; i < encoding.getThreadOrder().size(); i++) { + if (encoding.getThreadOrder()[i] == splitDim) break; numContiguousValues *= encoding.getSizePerThread()[i]; } @@ -336,7 +336,6 @@ struct BroadcastOpConversion unsigned rank = srcTy.getRank(); auto typeConverter = getTypeConverter(); assert(rank == resultTy.getRank()); - auto order = triton::gpu::getOrder(srcLayout); auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); SmallVector srcVals = unpackLLElements(loc, src, rewriter); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 6a820fc2b37a..d0365b4cee37 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) { return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType); } -/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr. +/* Utility function used by get.*Order methods of SliceEncodingAttr. * Erase dim and decrease all values larger than dim by 1. * Example: order = [0, 2, 4, 3, 1], dim = 2 * resOrder = [0, 3, 2, 1] @@ -262,29 +262,11 @@ SmallVector getOrderForDotOperand(unsigned opIdx, unsigned rank, } SmallVector getWarpOrder(Attribute layout) { - if (auto dotLayout = dyn_cast(layout)) { - if (isa(dotLayout.getParent())) { - return getWarpOrder(dotLayout.getParent()); - } - } - auto order = getOrder(layout); - // FIXME: At the moment, warpOrder in Ampere is N-major but in Hopper it's - // M-major This is awkward. Since we can choose any warpOrder in Ampere, we - // should probably choose M-major and change `LinearLayoutConversion.cpp` and - // `MMAv2.cpp` to match. - if (auto mmaLayout = dyn_cast(layout)) { - if (mmaLayout.isHopper()) { - // Hopper MMA instructions force warps to be column-major - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-wgmma-mma-async-m64nnk8 - return getMatrixOrder(order.size(), /*rowMajor*/ false); - } - } else if (auto dotOpLayout = dyn_cast(layout)) { - // It's quite weird to talk about warp order when that the warps - // are broadcasted along the K dimension - llvm::report_fatal_error( - "DotOperandEncoding::getWarpOrder not implemented"); - } - return order; + if (auto distributedLayout = mlir::dyn_cast(layout)) + return distributedLayout.getWarpOrder(); + else + llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); + return {}; } SmallVector getOrder(Attribute layout) { @@ -293,7 +275,7 @@ SmallVector getOrder(Attribute layout) { } if (auto mmaLayout = dyn_cast(layout)) { // Order doesn't really matter. We just have to be consistent when unpacking - // the elements in the MMAv2/V3 lowerings. We choose row-major + // the output elements in the LLVM lowerings. 
We choose row-major auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); return getMatrixOrder(rank, /*rowMajor*/ true); @@ -318,7 +300,7 @@ SmallVector getOrder(Attribute layout) { llvm::report_fatal_error("Unimplemented usage of getOrder"); return {}; -}; +} SmallVector getThreadOrder(Attribute layout) { if (auto distributedLayout = mlir::dyn_cast(layout)) @@ -326,7 +308,7 @@ SmallVector getThreadOrder(Attribute layout) { else llvm::report_fatal_error("Unimplemented usage of getThreadOrder"); return {}; -}; +} CTALayoutAttr getCTALayout(Attribute layout) { if (auto distributedLayout = @@ -769,7 +751,8 @@ SmallVector SliceEncodingAttr::getWarpsPerCTA() const { return warpsPerCTA; } SmallVector SliceEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto parentWarpOrder = ::getWarpOrder(getParent()); + return eraseOrder(parentWarpOrder, getDim()); } SmallVector SliceEncodingAttr::getThreadsPerWarp() const { auto parent = getParent(); @@ -781,7 +764,8 @@ SmallVector SliceEncodingAttr::getThreadsPerWarp() const { return threadsPerWarp; } SmallVector SliceEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto parentThreadOrder = ::getThreadOrder(getParent()); + return eraseOrder(parentThreadOrder, getDim()); } SmallVector SliceEncodingAttr::getSizePerThread() const { auto sizePerThread = ::getSizePerThread(getParent()); @@ -1049,7 +1033,14 @@ SmallVector DotOperandEncodingAttr::getWarpsPerCTA() const { return warps; } SmallVector DotOperandEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + // FIXME(Lezcano): Preexisting. Do we want to have this path at all? + if (mlir::isa(getParent())) { + return ::getWarpOrder(getParent()); + } + // It's quite weird to talk about warp order when that the warps + // are broadcasted along the K dimension + llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented"); + return {}; } SmallVector DotOperandEncodingAttr::getThreadOrder() const { return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(), @@ -1597,7 +1588,7 @@ SmallVector AMDMfmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDMfmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDMfmaEncodingAttr::getThreadOrder() const { auto order = ::getOrder(*this); @@ -1766,7 +1757,7 @@ SmallVector AMDWmmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector AMDWmmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + return ::getOrder(*this); } SmallVector AMDWmmaEncodingAttr::getThreadOrder() const { return ::getOrder(*this); @@ -1890,7 +1881,11 @@ SmallVector NvidiaMmaEncodingAttr::getWarpsPerCTA() const { return SmallVector(getWarpsPerCTA__()); } SmallVector NvidiaMmaEncodingAttr::getWarpOrder() const { - return ::getWarpOrder(*this); + auto rank = getWarpsPerCTA().size(); + // Hopper (wgmma) uses column-major as this is embeded in the instruction + // For Ampere we can choose either row-major or column-major. 
+ // We choose row-major as the legacy path did so + return getMatrixOrder(rank, /*rowMajor*/ !isHopper()); } SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { auto rank = getWarpsPerCTA().size(); @@ -1914,10 +1909,11 @@ SmallVector NvidiaMmaEncodingAttr::getThreadsPerWarp() const { "getThreadsPerWarp not implemented for unknown Mma version "); } SmallVector NvidiaMmaEncodingAttr::getThreadOrder() const { - return ::getOrder(*this); + auto rank = getWarpsPerCTA().size(); + return getMatrixOrder(rank, /*rowMajor*/ true); } SmallVector NvidiaMmaEncodingAttr::getSizePerThread() const { - auto rank = ::getOrder(*this).size(); + auto rank = getWarpsPerCTA().size(); SmallVector res(rank, 1); if (isAmpere()) { res[rank - 2] = 2; @@ -2158,11 +2154,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const { if (opIdx == 0) { sizePerThread[rank - 2] = 2; sizePerThread[rank - 1] = 2 * kWidth; - } else if (opIdx == 1) { + } else { + assert(opIdx == 1); sizePerThread[rank - 2] = 2 * kWidth; sizePerThread[rank - 1] = 1; - } else { - llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1"); } return sizePerThread; } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 6796307b7e22..2668f384978e 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -327,6 +327,7 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); assert(k == 8 || k == 16 || k == 32); + // TODO Make the getOrder of Hopper explicit here via an assert MLIRContext *ctx = mma.getContext(); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index b1e296c1bbe4..dce6c7f2af1b 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass return; if (!cvtNeedsSharedMemory(srcType, dstType)) return; - auto srcOrder = triton::gpu::getOrder(srcEncoding); - auto rank = srcOrder.size(); + auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding); + auto rank = srcThreadOrder.size(); SmallVector sharedOrder; if (rank == 3) { // add all elements except the element that is zero for (unsigned i = 0; i < rank; ++i) - if (srcOrder[i] != 0) - sharedOrder.emplace_back(srcOrder[i]); + if (srcThreadOrder[i] != 0) + sharedOrder.emplace_back(srcThreadOrder[i]); sharedOrder.emplace_back(0); } else { - sharedOrder = srcOrder; + sharedOrder = srcThreadOrder; } auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 5265f631ad9e..1abacc4cc39b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -39,15 +39,16 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto sizePerThread = triton::gpu::getSizePerThread(layout); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); + auto threadOrder = 
triton::gpu::getThreadOrder(layout); + auto warpOrder = triton::gpu::getWarpOrder(layout); auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(triton::gpu::getWarpSize(layout)); Value laneId = urem(tid, warpSize); Value warpId = udiv(tid, warpSize); SmallVector multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension if (shape[dim] >= shapePerCTATile[dim]) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 36b14e270b27..6129d77f174c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand type.getShape(), type.getElementType(), triton::gpu::SharedEncodingAttr::get( op.getContext(), dstDotOp, type.getShape(), - triton::gpu::getOrder(parentEnc), + triton::gpu::getThreadOrder(parentEnc), triton::gpu::getCTALayout(parentEnc), type.getElementType()), srcType.getMemorySpace()); auto tmp = rewriter.create( diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index cb430d8fadef..db6a6c64a167 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -38,7 +38,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, auto sizePerThread = triton::gpu::getSizePerThread(layout); auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); - auto order = triton::gpu::getOrder(layout); + auto threadOrder = triton::gpu::getThreadOrder(layout); auto warpOrder = triton::gpu::getWarpOrder(layout); auto shapePerCTATile = triton::gpu::getShapePerCTATile(layout, shape); Value warpSize = i32_val(32); @@ -47,7 +47,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter, SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); for (unsigned dim = 0; dim < rank; ++dim) { // if there is no data replication across threads on this dimension if (shape[dim] >= shapePerCTATile[dim])
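
For anyone auditing further call sites after this patch, below is a minimal illustrative sketch (not part of the commit; the helper name is made up). It uses only the free functions getOrder and getThreadOrder declared in triton/Dialect/TritonGPU/IR/Dialect.h and checks whether they disagree for a given layout, which, per the commit message, is the case for DotOperand(mma(opIdx=1)) on NVIDIA hardware.

// Illustrative helper only (hypothetical, not part of this patch).
// Assumes the free functions from include/triton/Dialect/TritonGPU/IR/Dialect.h.
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Returns true when a call site that blindly uses getOrder() would traverse
// lanes in the wrong order for this layout, e.g. DotOperand(mma, opIdx=1).
static bool getOrderDisagreesWithThreadOrder(mlir::Attribute layout) {
  auto matrixOrder = mlir::triton::gpu::getOrder(layout);       // matrix/tile order
  auto threadOrder = mlir::triton::gpu::getThreadOrder(layout); // per-lane order
  return matrixOrder != threadOrder;
}

When such a check returns true, a heuristic that consumed getOrder where it meant getThreadOrder would be reasoning about the wrong dimension, which is exactly the class of bug the call-site audit above addresses.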