From 8287311bf24f70885da1bb834aa8eade7de28d3c Mon Sep 17 00:00:00 2001
From: Ilya Veselov
Date: Mon, 25 Nov 2024 15:32:29 +0100
Subject: [PATCH] [AMD] Use Linear Layout conversions for AMDWmma

Enable LL conversions for WMMA as well as for MFMA layouts.

See also: https://github.com/triton-lang/triton/pull/5210
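
As an illustrative sketch (the layout attributes mirror the ones defined
in the new tests below), a conversion such as

  #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
  #mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
  %0 = triton_gpu.convert_layout %arg0 : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>

is now lowered through the LinearLayout-based ConvertLayoutOp pattern
instead of the legacy conversion. WMMA dot operand layouts still fall
back to the legacy path (see the layoutIsOK check below).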
Signed-off-by: Ilya Veselov
---
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |  7 +-
 .../amd/tritongpu_wmma_dot_to_llvm.mlir       | 65 +++++++++++++++++++
 2 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index aedb18a245e68..99d20c2557435 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -374,13 +374,16 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
     // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
     // completed before we can remove the layoutIsOK check:
-    // 1. Support for AMD's WMMA
+    // 1. Support for AMD's WMMA dot operand
     std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
       if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
         layout = dotOperand.getParent();
+        if (isa<AMDWmmaEncodingAttr>(layout)) {
+          return false;
+        }
       }
-      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+      if (isa<MmaEncodingTrait>(layout)) {
         return !useLegacyMMAConversion;
       }
       if (isa<BlockedEncodingAttr>(layout)) {
diff --git a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
index e7dcb873d0642..cfbc5f2ca16ac 100644
--- a/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_wmma_dot_to_llvm.mlir
@@ -1,5 +1,6 @@
 // RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx1100 | FileCheck %s
 
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
 #mma1 = #triton_gpu.amd_wmma<{version = 1, warpsPerCTA = [2, 2]}>
 #mma2 = #triton_gpu.amd_wmma<{version = 2, warpsPerCTA = [2, 2]}>
@@ -97,6 +98,70 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
     // CHECK-COUNT-8: llvm.insertvalue {{.*}} : !llvm.struct<(f16, f16, f16, f16, f16, f16, f16, f16)>
     tt.return
   }
+
+  // CHECK-LABEL: blocked_to_wmma1
+  tt.func @blocked_to_wmma1(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma1>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma1
+  tt.func @slice_blocked_to_wmma1(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma1_to_blocked
+  tt.func @wmma1_to_blocked(%arg0: tensor<128x16xi32, #mma1>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma1> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma1_to_blocked
+  tt.func @slice_wmma1_to_blocked(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma1}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: blocked_to_wmma2
+  tt.func @blocked_to_wmma2(%arg0: tensor<128x16xi32, #blocked>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-32: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #blocked> -> tensor<128x16xi32, #mma2>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_blocked_to_wmma2
+  tt.func @slice_blocked_to_wmma2(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>) {
+    // CHECK-COUNT-16: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-1: llvm.insertvalue {{.*}} : !llvm.struct<(i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>>
+    tt.return
+  }
+
+  // CHECK-LABEL: wmma2_to_blocked
+  tt.func @wmma2_to_blocked(%arg0: tensor<128x16xi32, #mma2>) {
+    // CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<128x16xi32, #mma2> -> tensor<128x16xi32, #blocked>
+    tt.return
+  }
+
+  // CHECK-LABEL: slice_wmma2_to_blocked
+  tt.func @slice_wmma2_to_blocked(%arg0: tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>>) {
+    // CHECK-COUNT-1: llvm.extractvalue {{.*}} : !llvm.struct<(i32)>
+    // CHECK-COUNT-16: llvm.insertvalue {{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)>
+    %0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #mma2}>> -> tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
+    tt.return
+  }
 }
 
 // -----