Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU][NFC] Follow the official convention to define mfma/wmma attributes #18127

Merged
merged 2 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func.func @contract_to_mfma_32x32x8_mm(%a : vector<32x8xf16>, %b : vector<8x32xf
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<8x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
Expand Down Expand Up @@ -128,7 +128,7 @@ func.func @contract_to_mfma_16x16x16_mm(%a : vector<16x16xf16>, %b : vector<16x1
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>

%O = iree_vector_ext.to_layout %output to #layout_b : vector<16x16xf32>
Expand Down Expand Up @@ -216,7 +216,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch(%a : vector<64x8xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x32xf16> into vector<64x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x32xf32>
Expand Down Expand Up @@ -305,7 +305,7 @@ func.func @contract_to_mfma_32x32x8_mm_kbatch(%a : vector<32x16xf16>, %b : vecto
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x16xf16>, vector<16x32xf16> into vector<32x32xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x32xf32>
Expand Down Expand Up @@ -388,7 +388,7 @@ func.func @contract_to_mfma_32x32x8_mm_mnbatch_order(%a : vector<64x8xf16>, %b :
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<64x8xf16>, vector<8x96xf16> into vector<64x96xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<64x96xf32>
Expand Down Expand Up @@ -479,7 +479,7 @@ func.func @contract_to_mfma_32x32x8_mmt(%a : vector<32x8xf16>, %b : vector<64x8x
indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"],
kind = #vector.kind<add>,
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
iree.amdgpu.mma = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>
} %A, %B, %C : vector<32x8xf16>, vector<64x8xf16> into vector<32x64xf32>

%O = iree_vector_ext.to_layout %output to #layout_c : vector<32x64xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func.func @weight_dequant_matmul() {
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
func.func @conv() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {mma_schedule = #iree_gpu.mma_schedule<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>}>} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x34x34x1280xf16>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,19 +215,19 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return OpaqueMmaLayout{32, 32, 8, f16, f16, f32};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32: {
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
}
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return OpaqueMmaLayout{16, 16, 32, i8, i8, i32};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return OpaqueMmaLayout{32, 32, 16, i8, i8, i32};
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
Expand Down Expand Up @@ -277,7 +277,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
Expand All @@ -295,7 +295,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
// #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
// #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
Expand All @@ -316,8 +316,8 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
Expand All @@ -334,7 +334,7 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 8]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
Expand Down Expand Up @@ -437,26 +437,26 @@ MMAAttr::getABCVectorTypes() const {
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
auto cType = VectorType::get({16}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto aType = VectorType::get({8}, getAType());
auto bType = VectorType::get({8}, getBType());
auto cType = VectorType::get({16}, getCType());
Expand Down Expand Up @@ -485,11 +485,11 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8:
case MMAIntrinsic::WMMA_F16_16x16x16_F16:
case MMAIntrinsic::WMMA_F16_16x16x16_F32: {
return 1;
Expand All @@ -502,11 +502,11 @@ int64_t MMAAttr::getBlockSize() const {
int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return 64;
}
case MMAIntrinsic::WMMA_F16_16x16x16_F32:
Expand All @@ -524,20 +524,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 4}};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 8}};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{32, 2}, /*strides=*/{1, 32},
/*element=*/{1, 8}};
}
Expand All @@ -556,20 +556,20 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{8, 1}};
}
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{8, 1}};
}
Expand All @@ -585,14 +585,14 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
}
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
return {/*outer=*/{4, 1}, /*thread=*/{2, 32}, /*strides=*/{32, 1},
/*element=*/{4, 1}};
}
Expand Down Expand Up @@ -632,11 +632,11 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
rhs, acc)
.getResult();
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_I8_32x32x16_I32: {
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_F32_32x32x8_F16:
case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
case MMAIntrinsic::MFMA_I32_32x32x16_I8: {
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
Expand Down Expand Up @@ -716,8 +716,8 @@ LogicalResult MMAAttr::populateOperandOffsetsSizesStrides(
SmallVector<OpFoldResult> &offsets, SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
case MMAIntrinsic::MFMA_F32_16x16x16_F16:
case MMAIntrinsic::MFMA_I32_16x16x32_I8:
break;
default:
return failure();
Expand Down
20 changes: 10 additions & 10 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,23 +100,23 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
def MFMA_F32_16x16x16_F16 : I32EnumAttrCase<"MFMA_F32_16x16x16_F16", 1>;
def MFMA_F32_32x32x8_F16 : I32EnumAttrCase<"MFMA_F32_32x32x8_F16", 2>;
def MFMA_F32_16x16x32_F8E4M3FNUZ : I32EnumAttrCase<"MFMA_F32_16x16x32_F8E4M3FNUZ", 3>;
def MFMA_I32_16x16x32_I8 : I32EnumAttrCase<"MFMA_I32_16x16x32_I8", 4>;
def MFMA_I32_32x32x16_I8 : I32EnumAttrCase<"MFMA_I32_32x32x16_I8", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 6>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_F8E4M3FNUZ_16x16x32_F32,
MFMA_I8_16x16x32_I32,
MFMA_I8_32x32x16_I32,
MFMA_F32_16x16x16_F16,
MFMA_F32_32x32x8_F16,
MFMA_F32_16x16x32_F8E4M3FNUZ,
MFMA_I32_16x16x32_I8,
MFMA_I32_32x32x16_I8,
WMMA_F16_16x16x16_F32,
WMMA_F16_16x16x16_F16
]>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<2x3x4xf16>, vector<3x5x4xf16> into vector<2x5x4xf32>
Expand Down Expand Up @@ -99,7 +99,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
}
%3 = iree_gpu.multi_mma %0, %1, %2 #contraction_trait
: vector<4xf16>, vector<4xf16> into vector<4xf32>
Expand Down Expand Up @@ -127,7 +127,7 @@ def IREEGPU_MultiMmaOp : Op<IREEGPU_Dialect, "multi_mma", [
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"],
kind = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>,
kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>,
rhs_permutation = [1, 0]
}
%7 = iree_gpu.multi_mma %4, %5, %6 #contraction_trait
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@

module {
func.func @test_mfma_f16_16x16x16_f32() attributes {
mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>} {
mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_16x16x16_f32
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F16_16x16x16_F32>
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>

module {
func.func @test_mfma_f16_32x32x8_f32() attributes {
mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>} {
mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>} {
return
}
}
// CHECK-LABEL: func @test_mfma_f16_32x32x8_f32
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F16_32x32x8_F32>
// CHECK-SAME: mma_types = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>

module {
func.func @test_wmma_f16_16x16x16_f32() attributes {
Expand Down
Loading
Loading