Skip to content

Commit

Permalink
[LLVMGPU] Support CastTypeToFitMMA on TransformDialect script. (#17884)
Browse files Browse the repository at this point in the history
Previously, CastTypeToFitMMA relied on the `mma_schedule` attribute on
the function's translationInfo to obtain information about
`iree.amdgpu.mma` (the selected intrinsic).

While this is fine for the C++ pipeline, the IR generated from a
TransformDialect script does not have such information. Instead, IR
generated by a TD script typically annotates the
`iree.amdgpu.mma` attribute (the selected intrinsic) directly on the
vector.contract ops.

This is a crucial part of enabling a performant version of the latest
attention compilation pipeline (with online attention + transpose fusion),
which is based on TD scripts.

---------

Co-authored-by: Kunwar Grover <[email protected]>
  • Loading branch information
raikonenfnu and Groverkss authored Jul 12, 2024
1 parent 44808e1 commit be461bd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@ namespace mlir::iree_compiler {

namespace {

struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
UpcastContractOutput(MLIRContext *context,
IREE::GPU::MmaInterfaceAttr intrinsic,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), intrinsic(intrinsic) {}
struct UpcastContractOutput final : OpRewritePattern<vector::ContractionOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
Expand All @@ -40,6 +37,12 @@ struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
auto srcAType = contractOp.getLhsType();
auto srcBType = contractOp.getRhsType();

auto intrinsic = contractOp->getAttrOfType<IREE::GPU::MmaInterfaceAttr>(
"iree.amdgpu.mma");
if (!intrinsic) {
return rewriter.notifyMatchFailure(
contractOp, "could not find iree.amdgpu.mma attribute on contract");
}
auto [dstAElemType, dstBElemType, dstCElemType] =
intrinsic.getABCElementTypes();

Expand Down Expand Up @@ -67,9 +70,6 @@ struct UpcastContractOutput : OpRewritePattern<vector::ContractionOp> {
newContractOp);
return success();
}

private:
IREE::GPU::MmaInterfaceAttr intrinsic;
};

struct LLVMGPUCastTypeToFitMMAPass
Expand All @@ -89,17 +89,24 @@ struct LLVMGPUCastTypeToFitMMAPass
func->getAttrOfType<IREE::GPU::MMAScheduleAttr>(scheduleAttrName);
if (!scheduleAttr) {
DictionaryAttr configDict = getTranslationInfo(func).getConfiguration();
scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
configDict.get(scheduleAttrName));
if (configDict) {
scheduleAttr = dyn_cast_or_null<IREE::GPU::MMAScheduleAttr>(
configDict.get(scheduleAttrName));
}
}
if (!scheduleAttr) {
func.emitError() << "missing mma_schedule\n";
return signalPassFailure();

// Import mma type from dispatch schedule attribute if present.
if (scheduleAttr) {
func.walk([&](vector::ContractionOp contract) {
if (!contract->hasAttr("iree.amdgpu.mma")) {
contract->setAttr("iree.amdgpu.mma", scheduleAttr.getIntrinsic());
}
});
}

MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<UpcastContractOutput>(context, scheduleAttr.getIntrinsic());
patterns.add<UpcastContractOutput>(context);

if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
return signalPassFailure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,31 @@ func.func @wmma_matmul_48x32x32_mm(%lhs: vector<48x32xf16>, %rhs: vector<32x32xf
// CHECK-SAME: %[[A]], %[[B]], %[[EXT]] : vector<48x32xf16>, vector<32x32xf16> into vector<48x32xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<48x32xf32> to vector<48x32xf16>
// CHECK: return %[[TRUNC]] : vector<48x32xf16>

// -----

// This tests that cast_type_to_fit_mma works on the IR structure coming out of transform_dialect.

// IR generated in transform_dialect is different from the one in the C++ pipeline:
// it will not have an mma_schedule attribute on the function, but instead carries the
// "iree.amdgpu.mma" attribute directly on the vector.contract op.

func.func @transform_dialect_mfma_matmul_96x64x16(%lhs: vector<96x16xf16>, %rhs: vector<16x64xf16>, %init: vector<96x64xf16>) -> vector<96x64xf16> attributes {translation_info = #iree_codegen.translation_info<None workgroup_size = [64, 1, 1] subgroup_size = 64>} {
  %0 = vector.contract {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
    %lhs, %rhs, %init
    {iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>}
    : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf16>
  return %0 : vector<96x64xf16>
}

// CHECK-LABEL: func.func @transform_dialect_mfma_matmul_96x64x16
//  CHECK-SAME: (%[[A:.+]]: vector<96x16xf16>, %[[B:.+]]: vector<16x64xf16>, %[[INIT:.+]]: vector<96x64xf16>)
//       CHECK:   %[[EXT:.+]] = arith.extf %[[INIT]] : vector<96x64xf16> to vector<96x64xf32>
//       CHECK:   %[[MM:.+]] = vector.contract
//  CHECK-SAME:       indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]
//  CHECK-SAME:       iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>
//  CHECK-SAME:       %[[A]], %[[B]], %[[EXT]] : vector<96x16xf16>, vector<16x64xf16> into vector<96x64xf32>
//       CHECK:   %[[TRUNC:.+]] = arith.truncf %[[MM]] : vector<96x64xf32> to vector<96x64xf16>
//       CHECK:   return %[[TRUNC]] : vector<96x64xf16>

0 comments on commit be461bd

Please sign in to comment.