diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index 9e058347f750c..a4b50037f31d5 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -277,7 +277,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     // Order doesn't really matter. We just have to be consistent when unpacking
     // the output elements in the LLVM lowerings. We choose row-major
     auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout);
-    if (nvidiaMma && nvidiaMma.isHopper()) {
+    if (nvidiaMma) {
       llvm::report_fatal_error("Testing");
     }
     auto distributedLayout = cast<DistributedEncodingTrait>(layout);
@@ -285,6 +285,9 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     return getMatrixOrder(rank, /*rowMajor*/ true);
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
+    if (isa<NvidiaMmaEncodingAttr>(dotLayout.getParent())) {
+      llvm::report_fatal_error("Testing DotOperand");
+    }
     auto rank = dotLayout.getWarpsPerCTA().size();
     return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true);
   }
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 2668f384978e5..7cd082d747145 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -292,9 +292,12 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef<int64_t> shape,
   MLIRContext *ctx = mma.getContext();
   SmallVector<StringAttr> dimNames = standardOutDimNames(ctx, rank);

-  auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
+  // TODO Revert
+  auto orderedDimNames =
+      permuteDimNames(dimNames, getMatrixOrder(rank, /*rowMajor=*/true));
+  // auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma));
   // By using `reverse(dimNames)` below, we set the order to be row-major
-  assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));
+  // assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true));

   LinearLayout ctaLayout(
       {{S("register"), {{1, 0}, {0, 8}}},
@@ -876,11 +879,15 @@ LinearLayout ampereDotToLinearLayout(ArrayRef<int64_t> shape,
   MLIRContext *ctx = mma.getContext();

   // A and B have kMajor order
-  assert(getOrder(dot) ==
-         getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
-
-  auto kMajorDims =
-      permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
+  // TODO Revert
+  // assert(getOrder(dot) ==
+  //        getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));
+
+  // auto kMajorDims =
+  //     permuteDimNames(standardOutDimNames(ctx, rank), getOrder(dot));
+  auto kMajorDims = permuteDimNames(
+      standardOutDimNames(ctx, rank),
+      getOrderForDotOperand(dot.getOpIdx(), rank, /*kMajor=*/true));

   // Implement A. For B transpose in the end
   std::vector<std::vector<int32_t>> registers;
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
index 4695984acfd3b..cc02cc4cd376d 100644
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -228,9 +228,16 @@ class FuseTransHopper : public OpRewritePattern {

     // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3
     // without transpose for other data types.)
-    auto newInnerCvtOrder = getOrder(srcTy.getEncoding());
+    // TODO Revert
+    auto getCvtOrder = [](Attribute encoding) {
+      if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(encoding)) {
+        return getThreadOrder(nvidiaMma);
+      }
+      return getOrder(encoding);
+    };
+    auto newInnerCvtOrder = getCvtOrder(srcTy.getEncoding());
     if (auto cvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
-      newInnerCvtOrder = getOrder(cvt.getSrc().getType().getEncoding());
+      newInnerCvtOrder = getCvtOrder(cvt.getSrc().getType().getEncoding());
     }
     auto srcElemTy = allocType.getElementType();
     if (!srcElemTy.isF16() && !srcElemTy.isBF16()) {