Skip to content

Commit

Permalink
Revert "Revert "[LLVMGPU][ROCm] Add MFMA_F32_16x16x4_F32 instruction"… (
Browse files Browse the repository at this point in the history
iree-org#17921)

… (iree-org#17894)"

This reverts commit 02c2000.

Signed-off-by: Lubo Litchev <[email protected]>
  • Loading branch information
pashu123 authored and LLITCHEV committed Jul 30, 2024
1 parent ed94d3e commit cea7ec4
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 7 deletions.
4 changes: 2 additions & 2 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
// GFX942: target = #iree_gpu.target<arch = "gfx942",
// GFX942-SAME: wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8,
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
// GFX942-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
// GFX942-SAME: chip = <wgp_count = 304>>

// GFX940: target = #iree_gpu.target<arch = "gfx940",
// GFX940-SAME: mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],
// GFX940-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_F8E4M3FNUZ_16x16x32_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>],

// GFX1100: target = #iree_gpu.target<arch = "gfx1100",
// GFX1100-SAME: mma = [<WMMA_F16_16x16x16_F32>, <WMMA_F16_16x16x16_F16>]
Expand Down
49 changes: 49 additions & 0 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
Type i32 = IntegerType::get(context, 32);

switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return OpaqueMmaLayout{16, 16, 4, f32, f32, f32};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return OpaqueMmaLayout{16, 16, 16, f16, f16, f32};
}
Expand Down Expand Up @@ -255,6 +258,24 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
LayoutDimensionAttr::get(context, LayoutDimension::VECTORZ);
(void)laneZ, (void)vectorZ;
switch (type) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 1]>
// #layout_a = #iree_vector_ext.layout<#outer, #inner>
// #layout_b = #iree_vector_ext.layout<#inner, #outer>
// #layout_c = #iree_vector_ext.layout<#inner, #outer>

auto outer = PerDimLayoutAttr::get(context, {laneX}, {16});
auto inner = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 1});
auto aMLayout = outer;
auto aKLayout = inner;
auto bKLayout = inner;
auto bNLayout = outer;
auto cMLayout = PerDimLayoutAttr::get(context, {laneY, vectorX}, {4, 4});
auto cNLayout = outer;
return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
bNLayout, cMLayout, cNLayout};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
// #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
// #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
Expand Down Expand Up @@ -409,6 +430,12 @@ MMAAttr::getABCVectorTypes() const {
// amd_matrix_instruction_calculator tells us about the number of 32-bit
// registers. So need to adjust accordingly. All vectors should be 1-D.
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
auto aType = VectorType::get({1}, getAType());
auto bType = VectorType::get({1}, getBType());
auto cType = VectorType::get({4}, getCType());
return std::make_tuple(aType, bType, cType);
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
auto aType = VectorType::get({4}, getAType());
auto bType = VectorType::get({4}, getBType());
Expand Down Expand Up @@ -456,6 +483,7 @@ MMAAttr::getContractionLayout(vector::ContractionOp contract) const {

int64_t MMAAttr::getBlockSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
Expand All @@ -472,6 +500,7 @@ int64_t MMAAttr::getBlockSize() const {

int64_t MMAAttr::getSubgroupSize() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
Expand All @@ -490,6 +519,10 @@ int64_t MMAAttr::getSubgroupSize() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{16, 4}, /*strides=*/{1, 16},
/*element=*/{1, 4}};
Expand Down Expand Up @@ -518,6 +551,10 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getASingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{1, 1}};
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32: {
return {/*outer=*/{1, 1}, /*thread=*/{4, 16}, /*strides=*/{16, 1},
/*element=*/{4, 1}};
Expand Down Expand Up @@ -546,6 +583,7 @@ MMAAttr::SingleSubgroupLayout MMAAttr::getBSingleSubgroupLayout() const {

MMAAttr::SingleSubgroupLayout MMAAttr::getCSingleSubgroupLayout() const {
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32:
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
case MMAIntrinsic::MFMA_I8_16x16x32_I32: {
Expand Down Expand Up @@ -582,6 +620,17 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
return failure();
}
switch (getIntrinsic().getValue()) {
case MMAIntrinsic::MFMA_F32_16x16x4_F32: {
// Update the lhs and rhs to extract the first element since vector<1xT> is
// not supoorted by amgpu.mfma op.
lhs = builder.create<vector::ExtractOp>(loc, lhs, ArrayRef{int64_t{0}});
rhs = builder.create<vector::ExtractOp>(loc, rhs, ArrayRef{int64_t{0}});
auto [m, n, k] = getMNKShape();
return builder
.create<amdgpu::MFMAOp>(loc, resultType, m, n, k, getBlockSize(), lhs,
rhs, acc)
.getResult();
}
case MMAIntrinsic::MFMA_F16_16x16x16_F32:
case MMAIntrinsic::MFMA_F16_32x32x8_F32:
case MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ class IREEGPU_I32MmaEnumAttr<string name, string summary, list<I32EnumAttrCase>
}

// Format: <kind>_<input-type>_<M>x<N>x<K>_<output-type>
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 0>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 1>;
def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 2>;
def MFMA_F32_16x16x4_F32 : I32EnumAttrCase<"MFMA_F32_16x16x4_F32", 0>;
def MFMA_F16_16x16x16_F32 : I32EnumAttrCase<"MFMA_F16_16x16x16_F32", 1>;
def MFMA_F16_32x32x8_F32 : I32EnumAttrCase<"MFMA_F16_32x32x8_F32", 2>;
def MFMA_F8E4M3FNUZ_16x16x32_F32 : I32EnumAttrCase<"MFMA_F8E4M3FNUZ_16x16x32_F32", 3>;
def MFMA_I8_16x16x32_I32 : I32EnumAttrCase<"MFMA_I8_16x16x32_I32", 4>;
def MFMA_I8_32x32x16_I32 : I32EnumAttrCase<"MFMA_I8_32x32x16_I32", 5>;
// TODO: Create separate WMMA ops for AMD and NVIDIA GPUs
Expand All @@ -110,6 +111,7 @@ def WMMA_F16_16x16x16_F16 : I32EnumAttrCase<"WMMA_F16_16x16x16_F16", 7>;

def IREEGPU_MMAIntrinsic : IREEGPU_I32MmaEnumAttr<"MMAIntrinsic",
"Descriptor for different MMA intrinsics", [
MFMA_F32_16x16x4_F32,
MFMA_F16_16x16x16_F32,
MFMA_F16_32x32x8_F32,
MFMA_F8E4M3FNUZ_16x16x32_F32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,

const WgpDetails *getCDNA3WgpDetails() {
static const MMAIntrinsic cdna3MMAOps[] = {
MMAIntrinsic::MFMA_F32_16x16x4_F32,
MMAIntrinsic::MFMA_F16_16x16x16_F32,
MMAIntrinsic::MFMA_F16_32x32x8_F32,
MMAIntrinsic::MFMA_F8E4M3FNUZ_16x16x32_F32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,23 @@ static void padConvOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp) {
int64_t nSize = bounds[nDim];
int64_t kSize = bounds[kDim];

auto inpElemType =
cast<ShapedType>(linalgOp.getDpsInputOperand(0)->get().getType())
.getElementType();
auto kernelElemType =
cast<ShapedType>(linalgOp.getDpsInputOperand(1)->get().getType())
.getElementType();

// TODO: Generalize to other dimensions.
// Try to search for pad value and check only filter dimension is blocked.
SmallVector<std::array<int64_t, 3>> mnkPaddingCandidates;
for (const GPUMatmulShapeType &intrinsic : intrinsics) {

if (!(inpElemType == intrinsic.aType &&
kernelElemType == intrinsic.bType)) {
continue;
}

std::optional<int64_t> mPadding, nPadding, kPadding;
auto getPadding = [](int64_t value, int64_t padTo) {
return llvm::divideCeil(value, padTo) * padTo - value;
Expand Down
28 changes: 28 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2289,6 +2289,34 @@ iree_generated_e2e_runner_test(
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f32_large_cdna3_mfma
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
GENERATOR_ARGS
"--lhs_rhs_type=f32"
"--acc_type=f32"
"--shapes=gpu_large_aligned"
"--compilation_info=LLVMGPUVectorDistributeMFMA"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"rocm"
DRIVERS
"hip"
COMPILER_FLAGS
${IREE_HIP_TEST_COMPILER_FLAGS}
LABELS
"noasan"
"nomsan"
"notsan"
"noubsan"
"requires-gpu-cdna3"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_rocm_f16_large_cdna3_mfma_tb
Expand Down
16 changes: 14 additions & 2 deletions tests/e2e/matmul/generate_e2e_matmul_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@ def get_rocm_test_compilation_infos(
schedules = []
if intrinsic == "MFMA":
schedules = [
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 1, 2, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 1, 1, 2, 1, 1),
MMASchedule("MFMA_F32_16x16x4_F32", 2, 2, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 1),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 1, 2),
MMASchedule("MFMA_F16_16x16x16_F32", 1, 1, 1, 2, 1),
Expand Down Expand Up @@ -304,10 +309,17 @@ def get_rocm_test_compilation_infos(
for schedule in schedules:
# Skip schedules with an intrinsic which element type does not
# match the requested one.
if lhs_rhs_type.value.upper() not in schedule.intrinsic:
# Extracts the input type from strings containing either 'MFMA' or 'WMMA'
# followed by an underscore.
extract_input_type = lambda s: re.search(r"(?:MFMA|WMMA)_([^_]+)_", s).group(1)
if lhs_rhs_type.value.upper() != extract_input_type(schedule.intrinsic):
continue

if schedule.intrinsic == "MFMA_F16_16x16x16_F32":
if schedule.intrinsic == "MFMA_F32_16x16x4_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 4
elif schedule.intrinsic == "MFMA_F16_16x16x16_F32":
wg_tile_m = schedule.m_count * schedule.m_tile_count * 16
wg_tile_n = schedule.n_count * schedule.n_tile_count * 16
wg_tile_k = schedule.k_tile_count * 16
Expand Down

0 comments on commit cea7ec4

Please sign in to comment.