diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
index 84725a09563c..621430b7e064 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUCastTypeToFitMMA.cpp
@@ -23,11 +23,8 @@ namespace mlir::iree_compiler {
 
 namespace {
 
-struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
-  UpcastContractOutput(MLIRContext *context,
-                       IREE::GPU::MmaInterfaceAttr intrinsic,
-                       PatternBenefit benefit = 1)
-      : OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
+struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
@@ -40,6 +37,12 @@ struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
     auto srcAType = contractOp.getLhsType();
     auto srcBType = contractOp.getRhsType();
 
+    auto intrinsic = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
+        "iree.amdgpu.mma");
+    if (!intrinsic) {
+      return rewriter.notifyMatchFailure(
+          contractOp, "could not find iree.amdgpu.mma attribute on contract");
+    }
     auto [dstAElemType, dstBElemType, dstCElemType] =
         intrinsic.getABCElementTypes();
 
@@ -67,9 +70,6 @@ struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
                                          newContractOp);
     return success();
   }
-
-private:
-  IREE::GPU::MmaInterfaceAttr intrinsic;
 };
 
 struct LLVMGPUCastTypeToFitMMAPass
@@ -89,17 +89,24 @@
         func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
     if (!scheduleAttr) {
       DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
-      scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
-          configDict.get(scheduleAttrName));
+      if (configDict) {
+        scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
+            configDict.get(scheduleAttrName));
+      }
     }
-    if (!scheduleAttr) {
-      func.emitError() << "missing mma_schedule\n";
-      return signalPassFailure();
+
+    // Import mma type from dispatch schedule attribute if
present.
+    if (scheduleAttr) {
+      func.walk([&](vector::ContractionOp contract) {
+        if (!contract->hasAttr("iree.amdgpu.mma")) {
+          contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
+        }
+      });
     }
 
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
-    patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
+    patterns.add<UpcastContractOutput>(context);
 
     if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
       return signalPassFailure();
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
index 1f07bc7c8880..21eb880a4721 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/cast_type_to_fit_mma.mlir
@@ -86,3 +86,31 @@ func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf
 // CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
 // CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
 // CHECK: return %[[TRUNC]] : vector<48x32xf16>
+
+// -----
+
+// This tests that cast_type_to_fit_mma works on IR structure coming out of transform_dialect.
+
+// IR generated in transform_dialect is different from the one in C++ pipeline.
+// It will not have mma_schedule on function attributes, but instead it will have
+// "iree.amdgpu.mma" attribute directly on vector.contract.
+
+func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute>} {
+  %0 = vector.contract {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %lhs, %rhs, %init
+    {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}
+    : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
+  return %0 : vector<96x64xf16>
+}
+
+// CHECK-LABEL: func.func @transform_dialect_mfma_matmul_96x64x16
+// CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
+// CHECK: %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
+// CHECK: %[[MM:.+]] = vector.contract
+// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
+// CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
+// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
+// CHECK: return %[[TRUNC]] : vector<96x64xf16>