diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index df6029db0de2..20da9784495d 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -66,7 +66,7 @@ class ReduceOpHelper {
   // The shape of the shared memory space needed for the reduction.
   SmallVector<unsigned> getScratchRepShape();
 
-  SmallVector<unsigned> getOrderWithAxisAtBeginning();
+  SmallVector<unsigned> getThreadOrderWithAxisAtBeginning();
 
   unsigned getScratchSizeInBytes();
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index e2713916ea7e..020a9ea4d3bc 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -46,6 +46,9 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
   auto dstShapePerCTATile =
       gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
+  assert(srcTy.getRank() == dstTy.getRank() &&
+         "src and dst must have the same rank");
+
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
   for (unsigned d = 0; d < rank; ++d) {
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index f0c5ae3167ec..d2c2c9fd8da3 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1213,7 +1213,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
 
   // Here order should be ordered by contiguous first, so the first element
   // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   unsigned align = getPtrAlignment(ptr);
 
   auto uniqueContigPerThread =
@@ -1235,7 +1235,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   if (!axisInfo)
     return 1;
   auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto order = triton::gpu::getThreadOrder(layout);
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
@@ -1262,7 +1262,7 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto maskOrder = triton::gpu::getThreadOrder(tensorTy.getEncoding());
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
                                         << alignment);
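Note on the AxisInfo changes above: all three sites index `order[0]`, so they need the axis a single thread actually walks contiguously, which is what `getThreadOrder` guarantees (Triton orders list dimensions fastest-varying first). A minimal standalone sketch of how such an alignment query combines per-axis info with the thread order; the names `AxisInfo`, `ptrAlignment`, and `threadOrder` here are illustrative, not the real Triton API:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct AxisInfo {
  std::vector<int64_t> contiguity;   // contiguous elements per axis
  std::vector<int64_t> divisibility; // pointer divisibility per axis, in bytes
};

// threadOrder[0] is the fastest-varying axis *within a thread*; using a
// generic layout order here could pick an axis the thread never walks
// contiguously, overstating the achievable vector width.
int64_t ptrAlignment(const AxisInfo &info,
                     const std::vector<unsigned> &threadOrder,
                     int64_t elemNumBytes) {
  unsigned fast = threadOrder[0];
  int64_t maxMultiple =
      std::max<int64_t>(info.divisibility[fast] / elemNumBytes, 1);
  return std::min(maxMultiple, info.contiguity[fast]);
}

int main() {
  AxisInfo info{{1, 8}, {4, 16}};            // axis 1 is the contiguous one
  std::cout << ptrAlignment(info, {1, 0}, 2) // fp16: min(16B / 2B, 8) = 8
            << "\n";
}
```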
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index aa9f8b01eae1..8c62e738f764 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -32,9 +32,9 @@ int getParentAxis(Attribute layout, int axis) {
   return axis;
 }
 
-SmallVector<unsigned> getParentOrder(Attribute layout) {
+SmallVector<unsigned> getParentThreadOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
-    return getParentOrder(sliceEncoding.getParent());
+    return getParentThreadOrder(sliceEncoding.getParent());
   }
   return getThreadOrder(layout);
 }
@@ -44,12 +44,12 @@
 // TODO(jlebar): Move this class into namespace triton.
 bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
   return getParentAxis(getSrcLayout(), axis) ==
-         getParentOrder(getSrcLayout())[0];
+         getParentThreadOrder(getSrcLayout())[0];
 }
 
-SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
+SmallVector<unsigned> ReduceOpHelper::getThreadOrderWithAxisAtBeginning() {
   auto srcLayout = getSrcLayout();
-  auto order = getOrder(srcLayout);
+  auto order = getThreadOrder(srcLayout);
   auto it = std::find(order.begin(), order.end(), axis);
   // delete the axis from order
   order.erase(it);
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 8682706db899..5367ab89f713 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -283,7 +283,7 @@ struct ReduceOpConversion
         getMultiDimWarpId(helper, warpId, loc, rewriter);
     Value warpIdAxis = multiDimWarpId[axis];
 
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = it.second;
@@ -370,7 +370,7 @@ struct ReduceOpConversion
     Location loc = op.getLoc();
     auto srcLayout = helper.getSrcLayout();
     auto axis = op.getAxis();
-    auto smemOrder = helper.getOrderWithAxisAtBeginning();
+    auto smemOrder = helper.getThreadOrderWithAxisAtBeginning();
     SmallVector<Value> results(op.getNumOperands());
     for (unsigned i = 0; i < op.getNumOperands(); ++i) {
       auto elemTy = getElementType(op, i);
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
index 297a94e851f6..be5feedd67bf 100644
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -181,9 +181,9 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     int numContiguousValues = 1;
     auto encoding = cast<BlockedEncodingAttr>(
        cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-    int splitDim = encoding.getOrder().size() - 1;
-    for (int i = 0; i < encoding.getOrder().size(); i++) {
-      if (encoding.getOrder()[i] == splitDim)
+    int splitDim = encoding.getThreadOrder().size() - 1;
+    for (int i = 0; i < encoding.getThreadOrder().size(); i++) {
+      if (encoding.getThreadOrder()[i] == splitDim)
         break;
       numContiguousValues *= encoding.getSizePerThread()[i];
     }
@@ -336,7 +336,6 @@ struct BroadcastOpConversion
     unsigned rank = srcTy.getRank();
     auto typeConverter = getTypeConverter();
     assert(rank == resultTy.getRank());
-    auto order = triton::gpu::getOrder(srcLayout);
     auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
     auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
     SmallVector<Value> srcVals = unpackLLElements(loc, src, rewriter);
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 2bda9c586e44..9e058347f750 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -217,7 +217,7 @@ bool isExpensiveView(Type srcType, Type dstType) {
   return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
 }
 
-/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
+/* Utility function used by get.*Order methods of SliceEncodingAttr.
  * Erase dim and decrease all values larger than dim by 1.
 * Example:    order = [0, 2, 4, 3, 1], dim = 2
 *          resOrder = [0,    3, 2, 1]
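The erase-and-renumber semantics described in the comment above can be pinned down with a short standalone sketch (a hypothetical plain-C++ version; the in-tree helper operates on MLIR attribute types):

```cpp
#include <cassert>
#include <vector>

std::vector<unsigned> eraseOrder(const std::vector<unsigned> &order,
                                 unsigned dim) {
  std::vector<unsigned> res;
  for (unsigned d : order) {
    if (d == dim)
      continue;                         // drop the sliced dimension
    res.push_back(d > dim ? d - 1 : d); // renumber the dims above it
  }
  return res;
}

int main() {
  // The example from the comment: order = [0, 2, 4, 3, 1], dim = 2
  // yields resOrder = [0, 3, 2, 1].
  assert((eraseOrder({0, 2, 4, 3, 1}, 2) == std::vector<unsigned>{0, 3, 2, 1}));
}
```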
@@ -262,24 +262,11 @@ SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
 }
 
 SmallVector<unsigned> getWarpOrder(Attribute layout) {
-  if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    if (isa<AMDMfmaEncodingAttr>(dotLayout.getParent())) {
-      return getWarpOrder(dotLayout.getParent());
-    }
-  }
-
-  auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
-  if (nvidiaMma && nvidiaMma.isHopper()) {
-    auto rank = nvidiaMma.getWarpsPerCTA().size();
-    return getMatrixOrder(rank, /*rowMajor*/ false);
-  } else if (auto dotOpLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
-    // It's quite weird to talk about warp order when that the warps
-    // are broadcasted along the K dimension
-    llvm::report_fatal_error(
-        "DotOperandEncoding::getWarpOrder not implemented");
-  }
-
-  return getOrder(layout);
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getWarpOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getWarpOrder");
+  return {};
 }
 
 SmallVector<unsigned> getOrder(Attribute layout) {
@@ -317,7 +304,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
 
   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
-};
+}
 
 SmallVector<unsigned> getThreadOrder(Attribute layout) {
   if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
@@ -325,7 +312,7 @@ SmallVector<unsigned> getThreadOrder(Attribute layout) {
     return distributedLayout.getThreadOrder();
   else
     llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
   return {};
-};
+}
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
@@ -768,7 +755,8 @@ SmallVector<unsigned> SliceEncodingAttr::getWarpsPerCTA() const {
   return warpsPerCTA;
 }
 SmallVector<unsigned> SliceEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto parentWarpOrder = ::getWarpOrder(getParent());
+  return eraseOrder(parentWarpOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   auto parent = getParent();
@@ -780,7 +768,8 @@ SmallVector<unsigned> SliceEncodingAttr::getThreadsPerWarp() const {
   return threadsPerWarp;
 }
 SmallVector<unsigned> SliceEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto parentThreadOrder = ::getThreadOrder(getParent());
+  return eraseOrder(parentThreadOrder, getDim());
 }
 SmallVector<unsigned> SliceEncodingAttr::getSizePerThread() const {
   auto sizePerThread = ::getSizePerThread(getParent());
@@ -1048,7 +1037,14 @@ SmallVector<unsigned> DotOperandEncodingAttr::getWarpsPerCTA() const {
   return warps;
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  // FIXME(Lezcano): Preexisting. Do we want to have this path at all?
+  if (mlir::isa<AMDMfmaEncodingAttr>(getParent())) {
+    return ::getWarpOrder(getParent());
+  }
+  // It's quite weird to talk about warp order when the warps
+  // are broadcast along the K dimension
+  llvm::report_fatal_error("DotOperandEncoding::getWarpOrder not implemented");
+  return {};
 }
 SmallVector<unsigned> DotOperandEncodingAttr::getThreadOrder() const {
   return getOrderForDotOperand(getOpIdx(), getWarpsPerCTA().size(),
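The SliceEncodingAttr changes above are the point of the refactor: a slice's warp/thread order is now derived from its parent's order via the erase-and-renumber helper, instead of being recomputed from scratch. A self-contained sketch of that derivation (the name `eraseOrder` mirrors the in-tree helper; the example layout is made up):

```cpp
#include <cassert>
#include <vector>

// Same erase-and-renumber helper as in the earlier sketch.
static std::vector<unsigned> eraseOrder(const std::vector<unsigned> &order,
                                        unsigned dim) {
  std::vector<unsigned> res;
  for (unsigned d : order)
    if (d != dim)
      res.push_back(d > dim ? d - 1 : d);
  return res;
}

int main() {
  // A rank-3 parent whose thread order is [2, 1, 0] (dim 2 fastest).
  std::vector<unsigned> parentThreadOrder{2, 1, 0};
  // Slicing away dim 1 leaves a rank-2 layout whose thread order is the
  // parent's order with dim 1 erased and higher dims renumbered: [1, 0].
  assert((eraseOrder(parentThreadOrder, 1) == std::vector<unsigned>{1, 0}));
}
```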
@@ -1596,7 +1592,7 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
   auto order = ::getOrder(*this);
@@ -1765,7 +1761,7 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  return ::getOrder(*this);
 }
 SmallVector<unsigned> AMDWmmaEncodingAttr::getThreadOrder() const {
   return ::getOrder(*this);
@@ -1889,7 +1885,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpsPerCTA() const {
   return SmallVector<unsigned>(getWarpsPerCTA__());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getWarpOrder() const {
-  return ::getWarpOrder(*this);
+  auto rank = getWarpsPerCTA().size();
+  // Hopper (wgmma) uses column-major as this is embedded in the instruction.
+  // For Ampere we can choose either row-major or column-major.
+  // We choose row-major as the legacy path did.
+  return getMatrixOrder(rank, /*rowMajor*/ !isHopper());
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
   auto rank = getWarpsPerCTA().size();
@@ -1913,11 +1913,11 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadsPerWarp() const {
       "getThreadsPerWarp not implemented for unknown Mma version ");
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getThreadOrder() const {
-  auto rank = getThreadsPerWarp().size();
+  auto rank = getWarpsPerCTA().size();
   return getMatrixOrder(rank, /*rowMajor*/ true);
 }
 SmallVector<unsigned> NvidiaMmaEncodingAttr::getSizePerThread() const {
-  auto rank = ::getOrder(*this).size();
+  auto rank = getWarpsPerCTA().size();
   SmallVector<unsigned> res(rank, 1);
   if (isAmpere()) {
     res[rank - 2] = 2;
@@ -2158,11 +2158,10 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
   if (opIdx == 0) {
     sizePerThread[rank - 2] = 2;
     sizePerThread[rank - 1] = 2 * kWidth;
-  } else if (opIdx == 1) {
+  } else {
+    assert(opIdx == 1);
     sizePerThread[rank - 2] = 2 * kWidth;
     sizePerThread[rank - 1] = 1;
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
   return sizePerThread;
 }
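The new NvidiaMmaEncodingAttr::getWarpOrder above hinges on getMatrixOrder, whose definition is not part of this diff. A standalone sketch of its assumed behavior (dims listed fastest-varying first, leading batch dims slowest; the implementation below is a guess from the call sites, not the in-tree code):

```cpp
#include <cassert>
#include <numeric>
#include <vector>

std::vector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  assert(rank >= 2);
  std::vector<unsigned> order(rank);
  // Row-major: [rank-1, rank-2, ..., 0], i.e. the column dim varies fastest.
  std::iota(order.rbegin(), order.rend(), 0);
  if (!rowMajor)
    std::swap(order[0], order[1]); // column-major: rows vary fastest
  return order;
}

int main() {
  // Ampere keeps the legacy row-major warp order...
  assert((getMatrixOrder(2, /*rowMajor=*/true) == std::vector<unsigned>{1, 0}));
  // ...while Hopper (wgmma) bakes column-major into the instruction,
  // hence the rowMajor = !isHopper() in the diff.
  assert((getMatrixOrder(2, /*rowMajor=*/false) == std::vector<unsigned>{0, 1}));
}
```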