From c726428005d554f029f128a0e7ba50fab5d4fc6f Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 6 Nov 2024 22:43:06 +0000 Subject: [PATCH] Consolidate `getOrder` as "element order" and document implicit "tile order" This partially reverts commit 38a11b859fff79ea214256d3f1cfe43d54e36c2c. It also documents that we are implicitly choosing a way to tile a full tensor depending on the layout. See https://github.com/triton-lang/triton/pull/5085#issuecomment-2460925683 --- include/triton/Analysis/Utility.h | 2 +- .../Conversion/TritonGPUToLLVM/Utility.h | 7 ++++--- include/triton/Dialect/TritonGPU/IR/Dialect.h | 5 ++--- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 8 ++++++++ lib/Analysis/Allocation.cpp | 8 ++------ lib/Analysis/AxisInfo.cpp | 6 +++--- lib/Analysis/Utility.cpp | 10 +++++----- .../TritonGPUToLLVM/ReduceOpToLLVM.cpp | 4 ++-- .../TritonGPUToLLVM/ScanOpToLLVM.cpp | 6 +++--- .../TritonGPUToLLVM/ViewOpToLLVM.cpp | 6 +++--- lib/Dialect/TritonGPU/IR/Dialect.cpp | 4 ++-- .../TritonGPU/IR/LinearLayoutConversions.cpp | 20 ++++++++++--------- .../Transforms/ReduceDataDuplication.cpp | 10 +++++----- .../DecomposeUnsupportedConversions.cpp | 2 +- 14 files changed, 52 insertions(+), 46 deletions(-) diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index 20da9784495d..df6029db0de2 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -66,7 +66,7 @@ class ReduceOpHelper { // The shape of the shared memory space needed for the reduction. 
SmallVector<unsigned> getScratchRepShape(); - SmallVector<unsigned> getThreadOrderWithAxisAtBeginning(); + SmallVector<unsigned> getOrderWithAxisAtBeginning(); unsigned getScratchSizeInBytes(); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 35f1303fa1cf..253033e98e8c 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -466,15 +466,16 @@ emitBaseIndexWithinCTAForBlockedLayout(Location loc, RewriterBase &rewriter, auto sizePerThread = blockedLayout.getSizePerThread(); auto threadsPerWarp = blockedLayout.getThreadsPerWarp(); auto warpsPerCTA = blockedLayout.getWarpsPerCTA(); - auto order = blockedLayout.getOrder(); + auto threadOrder = blockedLayout.getThreadOrder(); + auto warpOrder = blockedLayout.getWarpOrder(); auto shapePerCTA = triton::gpu::getShapePerCTA(blockedLayout, shape); unsigned rank = shape.size(); // delinearize threadId to get the base index SmallVector<Value> multiDimWarpId = - delinearize(rewriter, loc, warpId, warpsPerCTA, order); + delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); SmallVector<Value> multiDimThreadId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector<Value> multiDimBase(rank); for (unsigned k = 0; k < rank; ++k) { diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index cfc00926ddc2..a9b49448c1d0 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -76,9 +76,8 @@ SmallVector<unsigned> getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape); // Returns the dimensions of the tensor from minor (fast-varying) to -// major (slow-varying). For blocked, mma, and dotOperand layouts, -// though the elements are in registers, the order refers to memory -// layout of the original tensor in global memory. 
+// major (slow-varying). For distributed layouts, this represents +// the order of the elements within a thread. // For shared Layout, the order refers to which dimension of the original tensor // is contiguous in shared memory. SmallVector<unsigned> getOrder(Attribute layout); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 8f9a1a850fd5..c3625aef73fb 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -474,7 +474,15 @@ layout = [0 4 8 12] [3 7 11 15] For the Threads Per Warp and Values Per Thread level, the linear id distribution is variant for each sub-class encoding. + +If the layout does not completely cover the tensor, we tile it until we cover the entire tensor. +The order of this tiling is not defined explicitly in the language at the time of this writing. +In most cases, it is the same as the order of the elements, but there are exceptions; see the FIXME below. }]; +// FIXME: See the sentence above and the comment in +// https://github.com/triton-lang/triton/pull/5089#discussion_r1832243393 +// We need to add a getTilingOrder method to this class that returns the order used +// to tile the tensor when the layout is not large enough. let methods = [ // Interface for the meta information about the multiple thread hierarchy. 
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 020a9ea4d3bc..131c1ff67e84 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -84,12 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy, assert(cvtNeedsSharedMemory(srcTy, dstTy)); - // FIXME This is NOT entirely correct - // This should be getElemOrder, but we don't have such a method - // TODO Implement getElemOrder and make sure it's consistent with - // getContigPerThread - auto inOrd = gpu::getThreadOrder(srcLayout); - auto outOrd = gpu::getThreadOrder(dstLayout); + auto inOrd = gpu::getOrder(srcLayout); + auto outOrd = gpu::getOrder(dstLayout); scratchConfig.order = outOrd; unsigned srcContigPerThread = diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index d2c2c9fd8da3..f0c5ae3167ec 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { // Here order should be ordered by contiguous first, so the first element // should have the largest contiguous. 
- auto order = triton::gpu::getThreadOrder(layout); + auto order = triton::gpu::getOrder(layout); unsigned align = getPtrAlignment(ptr); auto uniqueContigPerThread = @@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { if (!axisInfo) return 1; auto layout = tensorTy.getEncoding(); - auto order = triton::gpu::getThreadOrder(layout); + auto order = triton::gpu::getOrder(layout); auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); auto maxContig = axisInfo->getContiguity(order[0]); auto elemNumBits = triton::getPointeeBitWidth(tensorTy); @@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { auto *axisInfo = getAxisInfo(mask); if (!axisInfo) return 1; - auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding()); + auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment " << alignment); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 501e19722089..ac72b4f26cd6 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) { return axis; } -SmallVector<unsigned> getParentThreadOrder(Attribute layout) { +SmallVector<unsigned> getParentOrder(Attribute layout) { if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) { - return getParentThreadOrder(sliceEncoding.getParent()); + return getParentOrder(sliceEncoding.getParent()); } return getThreadOrder(layout); } @@ -44,12 +44,12 @@ SmallVector<unsigned> getParentThreadOrder(Attribute layout) { // TODO(jlebar): Move this class into namespace triton. 
bool ReduceOpHelper::isReductionOnLayoutFastAxis() { return getParentAxis(getSrcLayout(), axis) == - getParentThreadOrder(getSrcLayout())[0]; + getParentOrder(getSrcLayout())[0]; } -SmallVector ReduceOpHelper::getThreadOrderWithAxisAtBeginning() { +SmallVector ReduceOpHelper::getOrderWithAxisAtBeginning() { auto srcLayout = getSrcLayout(); - auto order = getThreadOrder(srcLayout); + auto order = getOrder(srcLayout); auto it = std::find(order.begin(), order.end(), axis); // delete the axis from order order.erase(it); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 829d4e7104f0..26dc8a537973 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -322,7 +322,7 @@ struct ReduceOpConversion getMultiDimWarpId(helper, warpId, loc, rewriter); Value warpIdAxis = multiDimWarpId[axis]; - auto smemOrder = helper.getThreadOrderWithAxisAtBeginning(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = it.second; @@ -409,7 +409,7 @@ struct ReduceOpConversion Location loc = op.getLoc(); auto srcLayout = helper.getSrcLayout(); auto axis = op.getAxis(); - auto smemOrder = helper.getThreadOrderWithAxisAtBeginning(); + auto smemOrder = helper.getOrderWithAxisAtBeginning(); SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemTy = getElementType(op, i); diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 969b227c8dda..64e6ca787780 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -389,10 +389,10 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); auto warpsPerCTA = 
triton::gpu::getWarpsPerCTA(srcEncoding); - auto order = triton::gpu::getOrder(srcEncoding); + auto threadOrder = triton::gpu::getThreadOrder(srcEncoding); auto warpOrder = triton::gpu::getWarpOrder(srcEncoding); SmallVector multiDimLaneId = - delinearize(rewriter, loc, laneId, threadsPerWarp, order); + delinearize(rewriter, loc, laneId, threadsPerWarp, threadOrder); SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, warpOrder); @@ -402,7 +402,7 @@ ScanOpConversion::getDelinearizedIds(ConversionPatternRewriter &rewriter, multiDimLaneId[axis] = i32_val(0); threadsPerWarp[axis] = 1; Value laneIdParallel = - linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, order); + linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, threadOrder); multiDimWarpId[axis] = i32_val(0); warpsPerCTA[axis] = 1; Value warpIdParallel = diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index be5feedd67bf..8ba0fd3356f6 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern { int numContiguousValues = 1; auto encoding = cast( cast(op.getSrc().getType()).getEncoding()); - int splitDim = encoding.getThreadOrder().size() - 1; - for (int i = 0; i < encoding.getThreadOrder().size(); i++) { - if (encoding.getThreadOrder()[i] == splitDim) + int splitDim = encoding.getOrder().size() - 1; + for (int i = 0; i < encoding.getOrder().size(); i++) { + if (encoding.getOrder()[i] == splitDim) break; numContiguousValues *= encoding.getSizePerThread()[i]; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 8462c24aea67..414141596c17 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -269,13 +269,13 @@ SmallVector getWarpOrder(Attribute layout) { return {}; } +// Returns the order of the 
elements in a layout from the fastest running +// dimension to the slowest SmallVector getOrder(Attribute layout) { if (auto blockedLayout = dyn_cast(layout)) { return llvm::to_vector(blockedLayout.getOrder()); } if (auto mmaLayout = dyn_cast(layout)) { - // Order doesn't really matter. We just have to be consistent when unpacking - // the output elements in the LLVM lowerings. We choose row-major auto distributedLayout = cast(layout); auto rank = distributedLayout.getWarpsPerCTA().size(); return getMatrixOrder(rank, /*rowMajor*/ true); diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 2668f384978e..eb605baf89e1 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -292,9 +292,9 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector dimNames = standardOutDimNames(ctx, rank); - auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma)); - // By using `reverse(dimNames)` below, we set the order to be row-major - assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); + // Mma is tiled in a row-major fashion + auto orderedDimNames = + permuteDimNames(dimNames, getMatrixOrder(rank, /*rowMajor=*/true)); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, @@ -327,7 +327,6 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, assert(n == 8 || n == 16 || n == 32 || n == 64 || n == 128 || n == 256); assert(k == 8 || k == 16 || k == 32); - // TODO Make the getOrder of Hopper explicit here via an assert MLIRContext *ctx = mma.getContext(); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, @@ -875,14 +874,17 @@ LinearLayout ampereDotToLinearLayout(ArrayRef shape, assert(mma.isAmpere()); MLIRContext *ctx = mma.getContext(); - // A and B have kMajor order - assert(getOrder(dot) == - getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true)); + 
+ // The A and B operands are tiled in a kMajor fashion + auto kMajorOrder = + getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true); auto kMajorDims = - permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot)); + permuteDimNames(standardOutDimNames(ctx, rank), kMajorOrder); + // This agrees with the order of the elements, which means that we can share + // the code below for both A and B without having to perform any swaps + assert(getOrder(dot) == kMajorOrder); - // Implement A. For B transpose in the end std::vector> registers; std::vector> lanes; int32_t i = 1; diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp index dce6c7f2af1b..b1e296c1bbe4 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp @@ -44,17 +44,17 @@ class TritonGPUReduceDataDuplicationPass return; if (!cvtNeedsSharedMemory(srcType, dstType)) return; - auto srcThreadOrder = triton::gpu::getThreadOrder(srcEncoding); - auto rank = srcThreadOrder.size(); + auto srcOrder = triton::gpu::getOrder(srcEncoding); + auto rank = srcOrder.size(); SmallVector sharedOrder; if (rank == 3) { // add all elements except the element that is zero for (unsigned i = 0; i < rank; ++i) - if (srcThreadOrder[i] != 0) - sharedOrder.emplace_back(srcThreadOrder[i]); + if (srcOrder[i] != 0) + sharedOrder.emplace_back(srcOrder[i]); sharedOrder.emplace_back(0); } else { - sharedOrder = srcThreadOrder; + sharedOrder = srcOrder; } auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp index 1aa2b516a559..40cb55bbc00d 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ 
b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -54,7 +54,7 @@ class DecomposeLocalLoadToDotOperand type.getShape(), type.getElementType(), triton::gpu::SharedEncodingAttr::get( op.getContext(), dstDotOp, type.getShape(), - triton::gpu::getThreadOrder(parentEnc), + triton::gpu::getOrder(parentEnc), triton::gpu::getCTALayout(parentEnc), type.getElementType()), srcType.getMemorySpace()); auto tmp = rewriter.create(