diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 65199ea8c281..465ec7267a17 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3337,7 +3337,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if type_a != "e5m2" or type_b != "e5m2":
+        if type_a != "e5m2" or (type_b != "e5m2" and type_b != "bf16"):
             pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 201a7b0212fe..03fb8c3c6e62 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -511,8 +511,9 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
                   aElemType == ScaleDotElemType::E5M2))
       return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS");
     if (!(bElemType == ScaleDotElemType::E4M3 ||
-          bElemType == ScaleDotElemType::E5M2))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS");
+          bElemType == ScaleDotElemType::E5M2 ||
+          bElemType == ScaleDotElemType::BF16))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8/bf16 RHS");
 
     MLIRContext *ctx = dotOp.getContext();
     auto moduleOp = dotOp->getParentOfType<ModuleOp>();
@@ -568,21 +569,25 @@
     auto toMMABf16 = [&](TensorValue v, int idx,
                          ScaleDotElemType type) -> TensorValue {
-      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3);
+      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3 ||
+             type == ScaleDotElemType::BF16);
+
       auto vType = v.getType();
-      auto newVEnc = DotOperandEncodingAttr::get(
+      auto newVEncoding = DotOperandEncodingAttr::get(
           ctx, idx, newRetType.getEncoding(), kWdith);
-      auto newVType = RankedTensorType::get(vType.getShape(),
-                                            vType.getElementType(), newVEnc);
+      auto newVType = RankedTensorType::get(
+          vType.getShape(), vType.getElementType(), newVEncoding);
       v = rewriter.create<ConvertLayoutOp>(v.getLoc(), newVType, v);
 
+      if (type == ScaleDotElemType::BF16)
+        return v;
-      auto vTypeFp8 =
-          RankedTensorType::get(vType.getShape(), enumToType(type), newVEnc);
+      auto vTypeFp8 = RankedTensorType::get(vType.getShape(), enumToType(type),
+                                            newVEncoding);
       v = cast<TensorValue>(
           rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
-      auto vTypeBf16 = RankedTensorType::get(vType.getShape(),
-                                             rewriter.getBF16Type(), newVEnc);
+      auto vTypeBf16 = RankedTensorType::get(
+          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
       return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
     };
     a = toMMABf16(a, 0, aElemType);