From ae00c4f3cb42abdc2078ab33c703b2a09935f6c1 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Thu, 25 Jul 2024 01:36:18 +0530 Subject: [PATCH] =?UTF-8?q?Revert=20"Revert=20"[LLVMGPU][ROCm]=20Add=20MFM?= =?UTF-8?q?A=5FF32=5F16x16x4=5FF32=20instruction"=E2=80=A6=20(#17921)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit … (#17894)" This reverts commit 02c2000795e157e4cf63fbac89d21a1ed886a7b0. --- .../ROCM/test/target_device_features.mlir | 4 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 49 +++++++++++++++++++ .../Codegen/Dialect/GPU/IR/IREEGPUEnums.td | 8 +-- .../Dialect/GPU/TargetUtils/KnownTargets.cpp | 1 + .../Preprocessing/Common/PadToIntrinsics.cpp | 13 +++++ tests/e2e/matmul/CMakeLists.txt | 28 +++++++++++ tests/e2e/matmul/generate_e2e_matmul_tests.py | 16 +++++- 7 files changed, 112 insertions(+), 7 deletions(-) diff --git a/compiler/plugins/target/ROCM/test/target_device_features.mlir b/compiler/plugins/target/ROCM/test/target_device_features.mlir index 7859c0baf110..5973c05c8488 100644 --- a/compiler/plugins/target/ROCM/test/target_device_features.mlir +++ b/compiler/plugins/target/ROCM/test/target_device_features.mlir @@ -6,13 +6,13 @@ // GFX942: target = #iree_gpu.target, , , , ], +// GFX942-SAME: mma = [, , , , , ], // GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], // GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>, // GFX942-SAME: chip = > // GFX940: target = #iree_gpu.target, , , , ], +// GFX940-SAME: mma = [, , , , , ], // GFX1100: target = #iree_gpu.target, ] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 2e4d771991db..86e7bc074fb9 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -211,6 +211,9 @@ static 
OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context, Type i32 = IntegerType::get(context, 32); switch (type) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + return OpaqueMmaLayout{16, 16, 4, f32, f32, f32}; + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: { return OpaqueMmaLayout{16, 16, 16, f16, f16, f32}; } @@ -255,6 +258,24 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context, LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ); (void)laneZ, (void)vectorZ; switch (type) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> + // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]> + // #layout_a = #iree_vector_ext.layout<#outer, #inner> + // #layout_b = #iree_vector_ext.layout<#inner, #outer> + // #layout_c = #iree_vector_ext.layout<#inner, #outer> + + auto outer = PerDimLayoutAttr::get(context, {laneX}, {16}); + auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1}); + auto aMLayout = outer; + auto aKLayout = inner; + auto bKLayout = inner; + auto bNLayout = outer; + auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4}); + auto cNLayout = outer; + return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout, + bNLayout, cMLayout, cNLayout}; + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: { // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]> // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]> @@ -409,6 +430,12 @@ MMAAttr::getABCVectorTypes() const { // amd_matrix_instruction_calculator tells us about the number of 32-bit // registers. So need to adjust accordingly. All vectors should be 1-D. 
switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + auto aType = VectorType::get({1}, getAType()); + auto bType = VectorType::get({1}, getBType()); + auto cType = VectorType::get({4}, getCType()); + return std::make_tuple(aType, bType, cType); + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: { auto aType = VectorType::get({4}, getAType()); auto bType = VectorType::get({4}, getBType()); @@ -456,6 +483,7 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const { int64_t MMAAttr::getBlockSize() const { switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F16_16x16x16_F32: case MMAIntrinsic::MFMA_F16_32x32x8_F32: case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: @@ -472,6 +500,7 @@ int64_t MMAAttr::getBlockSize() const { int64_t MMAAttr::getSubgroupSize() const { switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F16_16x16x16_F32: case MMAIntrinsic::MFMA_F16_32x32x8_F32: case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: @@ -490,6 +519,10 @@ int64_t MMAAttr::getSubgroupSize() const { MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const { switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16}, + /*element=*/{1, 1}}; + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: { return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16}, /*element=*/{1, 4}}; @@ -518,6 +551,10 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const { MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const { switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1}, + /*element=*/{1, 1}}; + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: { return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1}, /*element=*/{4, 
1}}; @@ -546,6 +583,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const { MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const { switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: case MMAIntrinsic::MFMA_F16_16x16x16_F32: case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: case MMAIntrinsic::MFMA_I8_16x16x32_I32: { @@ -582,6 +620,17 @@ FailureOr MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc, return failure(); } switch (getIntrinsic().getValue()) { + case MMAIntrinsic::MFMA_F32_16x16x4_F32: { + // Update the lhs and rhs to extract the first element since vector<1xT> is + // not supported by amdgpu.mfma op. + lhs = builder.create(loc, lhs, ArrayRef{int64_t{0}}); + rhs = builder.create(loc, rhs, ArrayRef{int64_t{0}}); + auto [m, n, k] = getMNKShape(); + return builder + .create(loc, resultType, m, n, k, getBlockSize(), lhs, + rhs, acc) + .getResult(); + } case MMAIntrinsic::MFMA_F16_16x16x16_F32: case MMAIntrinsic::MFMA_F16_32x32x8_F32: case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td index 108deb979542..55a83b3fc131 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td @@ -99,9 +99,10 @@ class IREEGPU_I32MmaEnumAttr } // Format: __xx_ -def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>; -def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>; -def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>; +def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>; +def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>; +def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>; +def MFMA_F8E4M3FNUZ_16x16x32_F32 :
I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>; def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>; def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>; // TODO: Create separate WMMA ops for AMD and NVIDIA GPUs @@ -110,6 +111,7 @@ def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>; def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic", "Descriptor for different MMA intrinsics", [ + MFMA_F32_16x16x4_F32, MFMA_F16_16x16x16_F32, MFMA_F16_32x32x8_F32, MFMA_F8E4M3FNUZ_16x16x32_F32, diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp index c9e200f221e7..2f3d254b6587 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp @@ -122,6 +122,7 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch, const WgpDetails *getCDNA3WgpDetails() { static const MMAIntrinsic cdna3MMAOps[] = { + MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F16_16x16x16_F32, MMAIntrinsic::MFMA_F16_32x32x8_F32, MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32, diff --git a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp index 63ea3107bdb3..721522afaef5 100644 --- a/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp +++ b/compiler/src/iree/compiler/Preprocessing/Common/PadToIntrinsics.cpp @@ -214,10 +214,23 @@ static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) { int64_t nSize = bounds[nDim]; int64_t kSize = bounds[kDim]; + auto inpElemType = + cast(linalgOp.getDpsInputOperand(0)->get().getType()) + .getElementType(); + auto kernelElemType = + cast(linalgOp.getDpsInputOperand(1)->get().getType()) + .getElementType(); + // TODO: Generalize to other dimensions. 
// Try to search for pad value and check only filter dimension is blocked. SmallVector> mnkPaddingCandidates; for (const GPUMatmulShapeType &intrinsic : intrinsics) { + + if (!(inpElemType == intrinsic.aType && + kernelElemType == intrinsic.bType)) { + continue; + } + std::optional mPadding, nPadding, kPadding; auto getPadding = [](int64_t value, int64_t padTo) { return llvm::divideCeil(value, padTo) * padTo - value; diff --git a/tests/e2e/matmul/CMakeLists.txt b/tests/e2e/matmul/CMakeLists.txt index c908a8a8c4fe..0556e756bef6 100644 --- a/tests/e2e/matmul/CMakeLists.txt +++ b/tests/e2e/matmul/CMakeLists.txt @@ -2289,6 +2289,34 @@ iree_generated_e2e_runner_test( "requires-gpu-cdna3" ) +iree_generated_e2e_runner_test( + NAME + e2e_matmul_rocm_f32_large_cdna3_mfma + TEST_TYPE + matmul + GENERATOR + "generate_e2e_matmul_tests.py" + GENERATOR_ARGS + "--lhs_rhs_type=f32" + "--acc_type=f32" + "--shapes=gpu_large_aligned" + "--compilation_info=LLVMGPUVectorDistributeMFMA" + TEST_RUNNER + iree_tools_testing_e2e_iree-e2e-matmul-test + TARGET_BACKENDS + "rocm" + DRIVERS + "hip" + COMPILER_FLAGS + ${IREE_HIP_TEST_COMPILER_FLAGS} + LABELS + "noasan" + "nomsan" + "notsan" + "noubsan" + "requires-gpu-cdna3" +) + iree_generated_e2e_runner_test( NAME e2e_matmul_rocm_f16_large_cdna3_mfma_tb diff --git a/tests/e2e/matmul/generate_e2e_matmul_tests.py b/tests/e2e/matmul/generate_e2e_matmul_tests.py index 8b1fbf615b3b..56d590bad7d8 100644 --- a/tests/e2e/matmul/generate_e2e_matmul_tests.py +++ b/tests/e2e/matmul/generate_e2e_matmul_tests.py @@ -261,6 +261,11 @@ def get_rocm_test_compilation_infos( schedules = [] if intrinsic == "MFMA": schedules = [ + MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1), + MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2), + MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1), + MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1), + MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2), MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1), 
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2), MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1), @@ -304,10 +309,17 @@ def get_rocm_test_compilation_infos( for schedule in schedules: # Skip schedules with an intrinsic which element type does not # match the requested one. - if lhs_rhs_type.value.upper() not in schedule.intrinsic: + # Extracts the input type from strings containing either 'MFMA' or 'WMMA' + # followed by an underscore. + extract_input_type = lambda s: re.search(r"(?:MFMA|WMMA)_([^_]+)_", s).group(1) + if lhs_rhs_type.value.upper() != extract_input_type(schedule.intrinsic): continue - if schedule.intrinsic == "MFMA_F16_16x16x16_F32": + if schedule.intrinsic == "MFMA_F32_16x16x4_F32": + wg_tile_m = schedule.m_count * schedule.m_tile_count * 16 + wg_tile_n = schedule.n_count * schedule.n_tile_count * 16 + wg_tile_k = schedule.k_tile_count * 4 + elif schedule.intrinsic == "MFMA_F16_16x16x16_F32": wg_tile_m = schedule.m_count * schedule.m_tile_count * 16 wg_tile_n = schedule.n_count * schedule.n_tile_count * 16 wg_tile_k = schedule.k_tile_count * 16