[GPU][NFC] Follow the official convention to define mfma/wmma attributes
The LLVM intrinsics and the official docs all use the
`[output_type]_MxNxK_[input_type]` format. This revision updates IREE's
definitions to follow that convention.

Some examples from the official docs:

- https://gpuopen.com/learn/wmma_on_rdna3/
- https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme/
- https://github.com/ROCm/amd_matrix_instruction_calculator
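
To make the new ordering concrete: `MFMA_F32_16x16x16_F16` reads as an MFMA that takes F16 inputs and accumulates into an F32 result over a 16x16x16 tile, matching LLVM intrinsic names such as `llvm.amdgcn.mfma.f32.16x16x16f16`. Below is a minimal, self-contained C++ sketch of how a name following the convention decomposes; the `MmaShape` struct and `parseMmaName` helper are hypothetical illustrations, not IREE APIs:

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical illustration of the
// <kind>_<output-type>_<M>x<N>x<K>_<input-type> naming convention.
struct MmaShape {
  std::string kind;       // e.g. "MFMA" or "WMMA"
  std::string outputType; // e.g. "F32" (accumulator/result element type)
  int64_t m = 0, n = 0, k = 0;
  std::string inputType;  // e.g. "F16" (A/B operand element type)
};

std::optional<MmaShape> parseMmaName(const std::string &name) {
  // Split "MFMA_F32_16x16x16_F16" on '_' into [MFMA, F32, 16x16x16, F16].
  std::vector<std::string> parts;
  std::stringstream ss(name);
  for (std::string piece; std::getline(ss, piece, '_');)
    parts.push_back(piece);
  if (parts.size() != 4)
    return std::nullopt;

  MmaShape s;
  s.kind = parts[0];
  s.outputType = parts[1];
  s.inputType = parts[3];
  // The third piece is the MxNxK tile shape, e.g. "16x16x16".
  std::stringstream shape(parts[2]);
  char x1 = 0, x2 = 0;
  if (!(shape >> s.m >> x1 >> s.n >> x2 >> s.k) || x1 != 'x' || x2 != 'x')
    return std::nullopt;
  return s;
}

int main() {
  // Old IREE name: MFMA_F16_16x16x16_F32 (input type first).
  // New IREE name: MFMA_F32_16x16x16_F16 (output type first).
  if (auto s = parseMmaName("MFMA_F32_16x16x16_F16"))
    std::cout << s->kind << ": " << s->m << "x" << s->n << "x" << s->k
              << " tile, " << s->inputType << " inputs -> " << s->outputType
              << " accumulator\n";
}
```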

Signed-off-by: hanhanW <[email protected]>
hanhanW committed Aug 6, 2024
1 parent 71f1e20 commit 6020cbf
Showing 28 changed files with 171 additions and 171 deletions.
@@ -51,7 +51,7 @@ func.func @contract_to_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<8x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
@@ -128,7 +128,7 @@ func.func @contract_to_mfma_16x16x16_mm(%a : vector<16x16xf16>, %b : vector<16x1
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>

%O = iree_vector_ext.to_layout %output to #layout_b : vector<16x16xf32>
@@ -216,7 +216,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch(%a : vector<64x8xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x32xf16> into vector<64x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x32xf32>
@@ -305,7 +305,7 @@ func.func @contract_to_mfma_32x32x8_mm_kbatch(%a : vector<32x16xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
@@ -388,7 +388,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch_order(%a : vector<64x8xf16>, %b :
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x96xf16> into vector<64x96xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x96xf32>
@@ -479,7 +479,7 @@ func.func @contract_to_mfma_32x32x8_mmt(%a : vector<32x8xf16>, %b : vector<64x8x
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
-iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<64x8xf16> into vector<32x64xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x64xf32>
@@ -242,7 +242,7 @@ func.func @weight_dequant_matmul() {
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
-func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
+func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
94 changes: 47 additions & 47 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -215,19 +215,19 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
@@ -277,7 +277,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -295,7 +295,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
// #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
// #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
@@ -316,8 +316,8 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -334,7 +334,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
@@ -437,26 +437,26 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({16}, getCType());
@@ -485,11 +485,11 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32:
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
@@ -502,11 +502,11 @@ int64_t MMAAttr::getBlockSize() const {
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return 64;
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
@@ -524,20 +524,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 1}};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 8}};
}
@@ -556,20 +556,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
}
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{8, 1}};
}
Expand All @@ -585,14 +585,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
@@ -632,11 +632,11 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
rhs, acc)
.getResult();
}
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_F16_32x32x8_F32:
-case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
-case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_F32_32x32x8_F16:
+case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
+case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
@@ -716,8 +716,8 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
switch (getIntrinsic().getValue()) {
-case MMAIntrinsic::MFMA_F16_16x16x16_F32:
-case MMAIntrinsic::MFMA_I8_16x16x32_I32:
+case MMAIntrinsic::MFMA_F32_16x16x16_F16:
+case MMAIntrinsic::MFMA_I32_16x16x32_I8:
break;
default:
return failure();
20 changes: 10 additions & 10 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
@@ -100,23 +100,23 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
-def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
-def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
-def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
-def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
-def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
+def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>;
+def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>;
+def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>;
+def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>;
+def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
-MFMA_F16_16x16x16_F32,
-MFMA_F16_32x32x8_F32,
-MFMA_F8E4M3FNUZ_16x16x32_F32,
-MFMA_I8_16x16x32_I32,
-MFMA_I8_32x32x16_I32,
+MFMA_F32_16x16x16_F16,
+MFMA_F32_32x32x8_F16,
+MFMA_F32_16x16x32_F8E4M3FNUZ,
+MFMA_I32_16x16x32_I8,
+MFMA_I32_32x32x16_I8,
WMMA_F16_16x16x16_F32,
WMMA_F16_16x16x16_F16
]>;
@@ -56,7 +56,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
@@ -99,7 +99,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<4xf16>, vector<4xf16> into vector<4xf32>
@@ -127,7 +127,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
-kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
+kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
rhs_permutation = [1, 0]
}
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
@@ -2,21 +2,21 @@

module {
func.func @test_mfma_f16_16x16x16_f32() attributes {
-mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} {
+mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_16x16x16_f32
-// CHECK-SAME:  mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
+// CHECK-SAME:  mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>

module {
func.func @test_mfma_f16_32x32x8_f32() attributes {
-mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>} {
+mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_32x32x8_f32
-// CHECK-SAME:  mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
+// CHECK-SAME:  mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>

module {
func.func @test_wmma_f16_16x16x16_f32() attributes {