From e6a175854d9fe881612571390977a5ceb4a27e54 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 21:42:19 -0500 Subject: [PATCH 1/7] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index a802d62ace1f..b30d0211f466 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -376,7 +376,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // completed before we can remove the layoutIsOK check: // 1. Support for AMD's MFMA and WMMA std::function layoutIsOK = [&](Attribute layout) { - if (auto nvidiaMma = dyn_cast(layout)) { + if (auto nvidiaMma = + isa(layout)) { if (useLegacyMMAConversion) { return false; } @@ -392,6 +393,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return true; } } + if (isa(dotOperand.getParent())) { + return true; + } return false; } if (isa(layout)) { From ba11ef6e59e1b024bdb30859c2046df5a1ce72f9 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 21:42:35 -0500 Subject: [PATCH 2/7] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b30d0211f466..08cbfcdd8ac9 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -374,7 +374,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere) // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be // completed before we can remove the layoutIsOK check: - // 1. Support for AMD's MFMA and WMMA + // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { if (auto nvidiaMma = isa(layout)) { From 484a1462d17b1914cdc27695be30cf6c4cb918e9 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 21:43:13 -0500 Subject: [PATCH 3/7] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 08cbfcdd8ac9..bb52da915238 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -384,11 +384,11 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return true; } if (auto dotOperand = dyn_cast(layout)) { + if (useLegacyMMAConversion) { + return false; + } if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { - if (useLegacyMMAConversion) { - return false; - } if (nvidiaMma.isAmpere()) { return true; } From 63215c85bf6935f7a94fe36f90aa0559a4401444 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 22:14:34 -0500 Subject: [PATCH 4/7] Update --- test/Conversion/amd/mfma-shortcut.mlir | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 83c9e535d8c0..a2c8f48718d9 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK-NOT: store // CHECK-NOT: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { // CHECK: store // CHECK: load + // CHECK: llvm.return %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } From 5d4ae07ee575a6dde30b160aec1b93ce632561b3 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 22:16:53 -0500 Subject: [PATCH 5/7] Update --- lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 351f72e16b9e..1ce59184b27d 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -384,16 +384,19 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion return true; } if (auto dotOperand = dyn_cast(layout)) { - if (useLegacyMMAConversion) { - return false; - } if (auto nvidiaMma = dyn_cast(dotOperand.getParent())) { + if (useLegacyMMAConversion) { + return false; + } if (nvidiaMma.isAmpere()) { return true; } } if (isa(dotOperand.getParent())) { + if (useLegacyMMAConversion) { + return false; + } return true; } return false; From b70faefc718ad7b1cef175ddb2cdc8ea14a271d5 Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 22:19:34 -0500 Subject: [PATCH 6/7] Update --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 1ce59184b27d..d23d15d4567e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -376,27 +376,26 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // completed before we can remove the layoutIsOK check: // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { - if (auto nvidiaMma = - isa(layout)) { + if (isa(layout)) { if (useLegacyMMAConversion) { return false; } return true; } if (auto dotOperand = dyn_cast(layout)) { - if (auto nvidiaMma = - dyn_cast(dotOperand.getParent())) { + auto parent = dotOperand.getParent(); + if (isa(parent)) { if (useLegacyMMAConversion) { return false; } + return true; + } + if (auto nvidiaMma = dyn_cast(parent)) { if (nvidiaMma.isAmpere()) { return true; } } - if (isa(dotOperand.getParent())) { - if (useLegacyMMAConversion) { - return false; - } + if (isa(parent)) { return true; } return false; From 028688818406f3b8c4b164daf6ed61a51d0a3d5c Mon Sep 17 00:00:00 2001 From: Jokeren Date: Wed, 20 Nov 2024 22:21:44 -0500 Subject: [PATCH 7/7] Update --- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index d23d15d4567e..62499d8208cf 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -377,18 +377,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion // 1. Support for AMD's WMMA std::function layoutIsOK = [&](Attribute layout) { if (isa(layout)) { - if (useLegacyMMAConversion) { - return false; - } - return true; + return !useLegacyMMAConversion; } if (auto dotOperand = dyn_cast(layout)) { auto parent = dotOperand.getParent(); - if (isa(parent)) { - if (useLegacyMMAConversion) { - return false; - } - return true; + if (isa(parent) && useLegacyMMAConversion) { + return false; } if (auto nvidiaMma = dyn_cast(parent)) { if (nvidiaMma.isAmpere()) {