From bc55363e092da8b5cb533002100757bec2198604 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 4 Nov 2024 16:56:10 +0000 Subject: [PATCH] Test Dot. Bypass a proper getOrder use --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 5 ++++- lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | 7 +++++-- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 11 +++++++++-- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9e058347f750..a4b50037f31d 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -277,7 +277,7 @@ SmallVector getOrder(Attribute layout) { // Order doesn't really matter. We just have to be consistent when unpacking // the output elements in the LLVM lowerings. We choose row-major auto nvidiaMma = dyn_cast(layout); - if (nvidiaMma && nvidiaMma.isHopper()) { + if (nvidiaMma) { llvm::report_fatal_error("Testing"); } auto distributedLayout = cast(layout); @@ -285,6 +285,9 @@ SmallVector getOrder(Attribute layout) { return getMatrixOrder(rank, /*rowMajor*/ true); } if (auto dotLayout = dyn_cast(layout)) { + if (isa(dotLayout.getParent())) { + llvm::report_fatal_error("Testing DotOperand"); + } auto rank = dotLayout.getWarpsPerCTA().size(); return getOrderForDotOperand(dotLayout.getOpIdx(), rank, /*kMajor*/ true); } diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 2668f384978e..6aab771e2f27 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -292,9 +292,12 @@ LinearLayout ampereMmaToLinearLayout(ArrayRef shape, MLIRContext *ctx = mma.getContext(); SmallVector dimNames = standardOutDimNames(ctx, rank); - auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma)); + // TODO Revert + auto orderedDimNames = + permuteDimNames(dimNames, getMatrixOrder(rank, /*rowMajor=*/true)); + // auto orderedDimNames = permuteDimNames(dimNames, getOrder(mma)); // By using `reverse(dimNames)` below, we set the order to be row-major - assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); + // assert(getOrder(mma) == getMatrixOrder(rank, /*rowMajor=*/true)); LinearLayout ctaLayout( {{S("register"), {{1, 0}, {0, 8}}}, diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 4695984acfd3..cc02cc4cd376 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -228,9 +228,16 @@ class FuseTransHopper : public OpRewritePattern { // MMAv3 with transpose only supports f16 and bf16. Fall back to MMAv3 // without transpose for other data types.) - auto newInnerCvtOrder = getOrder(srcTy.getEncoding()); + // TODO Revert + auto getCvtOrder = [](Attribute encoding) { + if (auto nvidiaMma = dyn_cast(encoding)) { + return getThreadOrder(nvidiaMma); + } + return getOrder(encoding); + }; + auto newInnerCvtOrder = getCvtOrder(srcTy.getEncoding()); if (auto cvt = trans.getSrc().getDefiningOp()) { - newInnerCvtOrder = getOrder(cvt.getSrc().getType().getEncoding()); + newInnerCvtOrder = getCvtOrder(cvt.getSrc().getType().getEncoding()); } auto srcElemTy = allocType.getElementType(); if (!srcElemTy.isF16() && !srcElemTy.isBF16()) {