[LLVMGPU][ROCm] Add MFMA_F32_16x16x4_F32 instruction (iree-org#17847)
Signed-off-by: Lubo Litchev <[email protected]>
pashu123 authored and LLITCHEV committed Jul 30, 2024
1 parent f58bcb0 commit f322f9b
Showing 6 changed files with 102 additions and 13 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
@@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
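For reference, MFMA_F32_16x16x4_F32 follows the <kind>_<input-type>_<M>x<N>x<K>_<output-type> naming scheme used for these intrinsics: one block multiplies a 16x4 f32 LHS by a 4x16 f32 RHS and accumulates into a 16x16 f32 result. A minimal reference sketch of those semantics (illustrative only, not IREE code):

    # Reference semantics of one MFMA_F32_16x16x4_F32 block:
    # C[16x16] += A[16x4] @ B[4x16], all elements f32.
    M, N, K = 16, 16, 4

    def mfma_f32_16x16x4_f32(A, B, C):
        for m in range(M):
            for n in range(N):
                for k in range(K):
                    C[m][n] += A[m][k] * B[k][n]
        return C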
49 changes: 49 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
@@ -210,6 +210,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
Type i32 = IntegerType::get(context, 32);

switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
@@ -251,6 +254,24 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ);
(void)laneZ, (void)vectorZ;
switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
// #layout_c = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
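One way to read the layout comments above (an assumption consistent with the strides used later in this file, taking lane = laneY * 16 + laneX in a 64-lane subgroup): LANEX spreads the 16 rows of A across lanes, LANEY spreads the 4 K-columns, so each lane owns exactly one f32 of A and one of B, while the [4, 4] inner dimension of the C layout gives each lane four accumulator rows. A hypothetical sketch of the C ownership:

    # One reading of the C layout above: lane laneY*16 + laneX holds rows
    # 4*laneY .. 4*laneY + 3 of column laneX (4 f32 values per lane).
    def c_elements(lane):
        laneY, laneX = lane // 16, lane % 16
        return [(4 * laneY + i, laneX) for i in range(4)]

    assert c_elements(17) == [(4, 1), (5, 1), (6, 1), (7, 1)]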
@@ -404,6 +425,12 @@ MMAAttr::getABCVectorTypes() const {
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers, so we need to adjust accordingly. All vectors should be 1-D.
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
auto aType = VectorType::get({1}, getAType());
auto bType = VectorType::get({1}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
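A quick sanity check on these per-lane vector sizes: each operand tile divides evenly across the 64 lanes of a subgroup (the subgroup size reported below), so A and B contribute one f32 per lane and C contributes four:

    # Per-lane element counts, assuming a 64-lane subgroup:
    lanes = 64
    assert (16 * 4) // lanes == 1    # A tile -> vector<1xf32>
    assert (4 * 16) // lanes == 1    # B tile -> vector<1xf32>
    assert (16 * 16) // lanes == 4   # C tile -> vector<4xf32>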
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
@@ -450,6 +477,7 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {

int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
@@ -465,6 +493,7 @@ int64_t MMAAttr::getBlockSize() const {

int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
@@ -482,6 +511,10 @@ int64_t MMAAttr::getSubgroupSize() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
@@ -509,6 +542,10 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
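Reading these SingleSubgroupLayout values as outer repetitions, thread counts per dimension, lane-ID strides, and contiguous elements per lane (an assumption consistent with the F16 cases), the lane owning a coordinate is the stride-weighted sum of its thread indices:

    # Sketch: owning lane for a tile coordinate, given the per-dimension
    # lane strides shown above (helper name hypothetical).
    def owning_lane(coords, strides):
        return sum(c * s for c, s in zip(coords, strides))

    assert owning_lane((5, 3), (1, 16)) == 53   # A[5][3], strides {1, 16}
    assert owning_lane((3, 5), (16, 1)) == 53   # B[3][5], strides {16, 1}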
@@ -536,6 +573,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
@@ -571,6 +609,17 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
return failure();
}
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
// Extract the first element from lhs and rhs, since vector<1xT> is
// not supported by the amdgpu.mfma op.
lhs = builder.create<vector::ExtractOp>(loc, lhs, ArrayRef{int64_t{0}});
rhs = builder.create<vector::ExtractOp>(loc, rhs, ArrayRef{int64_t{0}});
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
rhs, acc)
.getResult();
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32:
@@ -99,16 +99,18 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
}

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 2>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 3>;
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 3>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 4>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 4>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 5>;
def WMMA_F16_16x16x16_F32 : I32EnumAttrCase<"WMMA_F16_16x16x16_F32", 5>;
def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 6>;
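Since the new case takes value 0, every later case shifts up by one. The format comment above fully determines each name; a hedged decoding sketch (helper name hypothetical):

    # Decode <kind>_<input-type>_<M>x<N>x<K>_<output-type>.
    def parse_mma_intrinsic(name):
        kind, in_type, shape, out_type = name.split("_")
        m, n, k = (int(d) for d in shape.split("x"))
        return kind, in_type, m, n, k, out_type

    assert parse_mma_intrinsic("MFMA_F32_16x16x4_F32") == (
        "MFMA", "F32", 16, 16, 4, "F32")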

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_I8_16x16x32_I32,
@@ -122,9 +122,8 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,

const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_F32_16x16x4_F32, MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32, MMAIntrinsic::MFMA_I8_16x16x32_I32,
MMAIntrinsic::MFMA_I8_32x32x16_I32,
};
static const WgpDetails cdna3Wgp = {
28 changes: 28 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
@@ -2193,6 +2193,34 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f32_large_cdna3_mfma
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistributeMFMA"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f16_large_cdna3_mfma_tb
15 changes: 13 additions & 2 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
@@ -260,6 +260,11 @@ def get_rocm_test_compilation_infos(
schedules = []
if intrinsic == "MFMA":
schedules = [
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
@@ -299,10 +304,16 @@ def get_rocm_test_compilation_infos(
for schedule in schedules:
# Skip schedules whose intrinsic element type does not
# match the requested one.
if lhs_rhs_type.value.upper() not in schedule.intrinsic:
# Search for the lhs_rhs type only in the first part of the intrinsic
# name, e.g. MFMA_F32_16x16x4_F32 -> MFMA_F32, so the trailing
# accumulator type cannot cause a false match.
if lhs_rhs_type.value.upper() not in schedule.intrinsic[:8]:
continue

if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 4
elif schedule.intrinsic == "MFMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
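A worked example of the new branch, assuming the MMASchedule fields are, in order, intrinsic, m_count, n_count, m_tile_count, n_tile_count, k_tile_count (matching the arithmetic above): MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2) yields a 32x32x8 workgroup tile:

    # MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2), intrinsic K = 4:
    m_count, n_count, m_tile_count, n_tile_count, k_tile_count = 2, 2, 1, 1, 2
    assert m_count * m_tile_count * 16 == 32   # wg_tile_m
    assert n_count * n_tile_count * 16 == 32   # wg_tile_n
    assert k_tile_count * 4 == 8               # wg_tile_k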
