diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index b707d8f7d328..214707f7eaa9 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -32,7 +32,7 @@ namespace {
 
 #define S(v) StringAttr::get(ctx, (v))
 
-// Returns ["out0", "out1", ..., "out<rank-1>"].
+// Returns ["dim0", "dim1", ..., "dim<rank-1>"].
 SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
   SmallVector<StringAttr> ret;
   for (int i = 0; i < rank; i++) {
@@ -71,14 +71,18 @@ void assertIsRegisterLayout(const LinearLayout &layout) {
                          expectedOuts.end()));
 }
 
-// Returns a 1D -> ND layout that's equivalent to creating a 1D -> 1D mapping of
-// size product(shape) and then reshaping to permute(shape, order).
-LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
-                        ArrayRef<unsigned> order,
-                        ArrayRef<StringAttr> outDimNames) {
+// Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
+// creating a 1D -> 1D mapping of size product(shape) and then reshaping to
+// permute(shape, order).
+LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
+                                ArrayRef<unsigned> order) {
   assert(shape.size() == order.size());
   MLIRContext *ctx = inDimName.getContext();
+  auto rank = shape.size();
+
+  // The order in triton is written wrt. [dim0, dim1, ...].
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+
   LinearLayout ret = LinearLayout::empty();
   for (int i = 0; i < shape.size(); i++) {
     // Start with the most-minor dimension, which is order[0].
@@ -491,7 +495,7 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So mulitply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -601,8 +605,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
     tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
   }
 
-  LinearLayout warpLayout =
-      identityND(kWarp, warpsPerCTA, warpOrder, outDimNames);
+  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
 
   LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                            warpLayout.transposeOuts(outDimNames);
@@ -684,7 +687,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // And each warp takes the same register and lane sub-layout. So mulitply with
   // an identity layout for the warp.
   LinearLayout warpLayout =
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
   LinearLayout ctaLayout = tileLayout * warpLayout;
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -700,9 +703,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   const auto &order = getOrder();
 
   LinearLayout ctaLayout =
-      identityND(S("register"), getSizePerThread(), order, outDimNames) *
-      identityND(S("lane"), getThreadsPerWarp(), order, outDimNames) *
-      identityND(S("warp"), getWarpsPerCTA(), order, outDimNames);
+      identityStandardND(S("register"), getSizePerThread(), order) *
+      identityStandardND(S("lane"), getThreadsPerWarp(), order) *
+      identityStandardND(S("warp"), getWarpsPerCTA(), order);
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
 }
@@ -711,11 +714,12 @@ LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
                            ArrayRef<unsigned> repOrder) {
   // Trivial layout mapping 0 -> (0, 0), but we set the order to repOrder
+  // Like LinearLayout::empty() but with a rank and an order
   int rank = repOrder.size();
   auto dimNames = standardOutDimNames(ctx, rank);
   auto trivialShape = SmallVector<unsigned>(rank, 1);
   LinearLayout ctaLayout =
-      identityND(S("register"), trivialShape, repOrder, dimNames);
+      identityStandardND(S("register"), trivialShape, repOrder);
 
   assert(rank >= 2);
   auto inner = order[0];
@@ -769,11 +773,7 @@ NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
   // The triton orders are defined on [dim0, dim1, ...], so we need to pass
   // those dims Then, for some reason, operator* requires the orders to match
   // so we need to reorder the outs to match
-  // FIXME(Lezcano). identityND should not take a dim name, as it's redundant.
-  // The order in triton assumes the standardDims, so it should
-  // use those.
-  ctaLayout *= identityND(S("warp"), getWarpsPerCTA(), getWarpOrder(),
-                          standardOutDimNames(ctx, rank))
+  ctaLayout *= identityStandardND(S("warp"), getWarpsPerCTA(), getWarpOrder())
                    .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
 
   return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
@@ -797,11 +797,8 @@ LinearLayout warpsNvidiaDot(MLIRContext *ctx, ArrayRef<unsigned> mmaWarpShape,
   // In other words, we need to broadcast along K
   auto rank = mmaWarpOrder.size();
   auto inner = isA ? rank - 1 : rank - 2;
-  auto outer = isA ? rank - 2 : rank - 1;
   auto dimNames = standardOutDimNames(ctx, rank);
-  auto trivialShape = SmallVector<unsigned>(rank, 1);
-  LinearLayout warpLayout =
-      identityND(S("warp"), trivialShape, mmaWarpOrder, dimNames);
+  LinearLayout warpLayout = LinearLayout::empty();
 
   // We have to broadcast along the inner dimension
   // For A, when moving along M we go from 0 to 2.
@@ -1086,9 +1083,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
 
   // Expand the `warp` dimension according to warpsPerCTA.
   auto mma = cast<NvidiaMmaEncodingAttr>(tensorTy.getEncoding());
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
 
   // Expand the `register` dimension so the size of columns matches `n`.
   int n = mma.getInstrShape()[1];
@@ -1126,9 +1122,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
       LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol);
 
   // Expand the `warp` dimension according to warpsPerCTA.
-  layout *=
-      identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}, {kRow, kCol})
-          .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
+  layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1})
+                .transposeOuts(llvm::to_vector(layout.getOutDimNames()));
   auto ret =
       combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());
   auto tensorShapePerCTA = getShapePerCTA(mma, tensorTy.getShape());
@@ -1138,9 +1133,8 @@ std::optional<LinearLayout> chooseStMatrixLayoutNoLeadingOffset(
   ret = ensureLayoutNotSmallerThan(ret, namedTensorShape);
   ret = ensureLayoutNotLargerThan(ret, namedTensorShape);
   return ret.transposeOuts(llvm::to_vector(layout.getOutDimNames()))
-             .reshapeOuts({{S("offset"), ret.getTotalOutDimSize()},
-                           {S("iteration"), 1}}) *
-         identityND(kBlock, {1, 1}, {0, 1}, {S("offset"), S("iteration")});
+      .reshapeOuts(
+          {{S("offset"), ret.getTotalOutDimSize()}, {S("iteration"), 1}});
 }
 
 } // anonymous namespace
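
--
Not part of the patch: a minimal sketch of what identityStandardND computes,
for readers reviewing the refactor above. It assumes the loop body (whose
header survives as context in the second hunk) multiplies identity1D layouts
along `order`, exactly as the old identityND did. `exampleRegisterLayout` is a
hypothetical helper introduced only for illustration; standardOutDimNames,
LinearLayout::empty, and LinearLayout::identity1D are the functions visible
in the diff.

// Sketch: what identityStandardND(S("register"), /*shape=*/{2, 4},
// /*order=*/{1, 0}) expands to, with dim1 as the most-minor dimension.
LinearLayout exampleRegisterLayout(MLIRContext *ctx) {
  StringAttr regDim = StringAttr::get(ctx, "register");
  // Standard output dims, always ["dim0", "dim1", ...].
  SmallVector<StringAttr> outs = standardOutDimNames(ctx, /*rank=*/2);

  LinearLayout ret = LinearLayout::empty();
  // Most-minor dimension first: order[0] == 1, so dim1 with size shape[1] == 4.
  ret *= LinearLayout::identity1D(4, regDim, outs[1]);
  // Then order[1] == 0: dim0 with size shape[0] == 2.
  ret *= LinearLayout::identity1D(2, regDim, outs[0]);
  return ret;
}

// The result matches the old identityND(regDim, {2, 4}, {1, 0}, outs); the
// only behavioral change in this patch is that the output dims are always the
// standard [dim0, dim1, ...] rather than a caller-supplied list, which is why
// every identityND(..., outDimNames) call site above drops its last argument.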