diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 6a7613453b35..0b0a3b55cc03 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -32,7 +32,7 @@ namespace { #define S(v) StringAttr::get(ctx, (v)) -// Returns ["out0", "out1", ..., "out<rank-1>"]. +// Returns ["dim0", "dim1", ..., "dim<rank-1>"]. SmallVector standardOutDimNames(MLIRContext *ctx, int rank) { SmallVector ret; for (int i = 0; i < rank; i++) { @@ -74,8 +74,8 @@ void assertIsRegisterLayout(const LinearLayout &layout) { // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to // creating a 1D -> 1D mapping of size product(shape) and then reshaping to // permute(shape, order). -LinearLayout identityND(StringAttr inDimName, ArrayRef shape, - ArrayRef order) { +LinearLayout identityStandardND(StringAttr inDimName, ArrayRef shape, + ArrayRef order) { assert(shape.size() == order.size()); MLIRContext *ctx = inDimName.getContext(); auto rank = shape.size(); @@ -305,7 +305,8 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, {S("lane"), {{2, 0}, {4, 0}, {0, 1}, {0, 2}, {0, 4}}}}, ArrayRef(orderedDimNames).take_front(2)); assert(getWarpOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); - ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder()); + ctaLayout *= + identityStandardND(S("warp"), mma.getWarpsPerCTA(), mma.getWarpOrder()); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } @@ -342,8 +343,9 @@ LinearLayout hopperMmaToLinearLayout(ArrayRef shape, // Since the warpOrder needs to be M-major, we need to transpose the out // dimensions AND transpose the order assert(getWarpOrder(mma) == SmallVector({0, 1})); - ctaLayout *= identityND(S("warp"), mma.getWarpsPerCTA(), /*order=*/{0, 1}) - .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); + ctaLayout *= + identityStandardND(S("warp"), 
mma.getWarpsPerCTA(), /*order=*/{0, 1}) + .transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames())); return combineCtaCgaWithShape(ctaLayout, mma.getCTALayout(), shape); } @@ -558,7 +560,8 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. - LinearLayout warpLayout = identityND(S("warp"), getWarpsPerCTA(), order); + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), order); LinearLayout ctaLayout = tileLayout * warpLayout; return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); @@ -668,7 +671,7 @@ mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout, tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]); } - LinearLayout warpLayout = identityND(kWarp, warpsPerCTA, warpOrder); + LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder); LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) * warpLayout.transposeOuts(outDimNames); @@ -749,7 +752,8 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef shape) const { // And each warp takes the same register and lane sub-layout. So mulitply with // an identity layout for the warp. 
- LinearLayout warpLayout = identityND(S("warp"), getWarpsPerCTA(), order); + LinearLayout warpLayout = + identityStandardND(S("warp"), getWarpsPerCTA(), order); LinearLayout ctaLayout = tileLayout * warpLayout; return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); @@ -765,9 +769,9 @@ BlockedEncodingAttr::toLinearLayout(ArrayRef shape) const { const auto &order = getOrder(); LinearLayout ctaLayout = - identityND(S("register"), getSizePerThread(), order) * - identityND(S("lane"), getThreadsPerWarp(), order) * - identityND(S("warp"), getWarpsPerCTA(), order); + identityStandardND(S("register"), getSizePerThread(), order) * + identityStandardND(S("lane"), getThreadsPerWarp(), order) * + identityStandardND(S("warp"), getWarpsPerCTA(), order); return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); } @@ -1151,7 +1155,7 @@ std::optional chooseStMatrixLayoutLeadingOffset( // Expand the `warp` dimension according to warpsPerCTA. auto mma = cast(tensorTy.getEncoding()); - layout *= identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) .transposeOuts(llvm::to_vector(layout.getOutDimNames())); // Expand the `register` dimension so the size of columns matches `n`. @@ -1190,7 +1194,7 @@ std::optional chooseStMatrixLayoutNoLeadingOffset( LinearLayout::identity1D(n / layout.getOutDimSize(kCol), kReg, kCol); // Expand the `warp` dimension according to warpsPerCTA. - layout *= identityND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) + layout *= identityStandardND(kWarp, mma.getWarpsPerCTA(), /*order=*/{0, 1}) .transposeOuts(llvm::to_vector(layout.getOutDimNames())); auto ret = combineCtaCgaWithShape(layout, mma.getCTALayout(), tensorTy.getShape());