diff --git a/mlir/lib/Dialect/Rock/utility/AmdArchDb.cpp b/mlir/lib/Dialect/Rock/utility/AmdArchDb.cpp index b87bac772f14..7dd482b2e965 100644 --- a/mlir/lib/Dialect/Rock/utility/AmdArchDb.cpp +++ b/mlir/lib/Dialect/Rock/utility/AmdArchDb.cpp @@ -111,8 +111,10 @@ GemmFeatures mlir::rock::AmdArchInfo::getDefaultFeatures(Type dataType) { bool isWmma = bitEnumContainsAll(theseFeatures, GemmFeatures::wmma); Type elementType = getElementTypeOrSelf(dataType); if (isWmma) { - if (!elementType.isF16() && !elementType.isBF16() && - !elementType.isInteger(8)) { + if (!(isa(elementType) || + elementType.isInteger(8) || + (hasFp8ConversionInstrs && + isa(elementType)))) { theseFeatures = bitEnumClear(theseFeatures, GemmFeatures::wmma); } } diff --git a/mlir/test/rocmlir-gen/wmma-enablement.mlir b/mlir/test/rocmlir-gen/wmma-enablement.mlir new file mode 100644 index 000000000000..ea033c10b691 --- /dev/null +++ b/mlir/test/rocmlir-gen/wmma-enablement.mlir @@ -0,0 +1,12 @@ +// RUN: rocmlir-gen --arch gfx1201 --operation gemm --operation gemm -wmma infer -t f16 -p | grep '|wmma' | count 1 +// RUN: rocmlir-gen --arch gfx1201 --operation gemm -wmma infer -t fp8_fp8 -p | grep '|wmma' | count 1 +// RUN: rocmlir-gen --arch gfx1201 --operation gemm -wmma infer -t bf8_bf8 -p | grep '|wmma' | count 1 +// RUN: rocmlir-gen --arch gfx1201 --operation gemm -wmma infer -t fp8_fp8 -force-f8-types=fnuz -p | not grep '|wmma' + +// RUN: rocmlir-gen --arch gfx1100 --operation gemm -wmma infer -t f16 -p | grep '|wmma' | count 1 +// RUN: rocmlir-gen --arch gfx1100 --operation gemm -wmma infer -t fp8_fp8 -p | not grep '|wmma' + +// YES: rock.gemm +// YES-SAME: features = {{[^ ]*}}wmma +// NO: rock.gemm +// NO-NOT: wmma