[AMD] Enable scaled_dot(-, bf16)
antiagainst committed Oct 31, 2024
1 parent ee5876c commit b8346ba
Showing 2 changed files with 15 additions and 10 deletions.
python/test/unit/language/test_core.py (2 changes: 1 addition & 1 deletion)
@@ -3337,7 +3337,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if type_a != "e5m2" or type_b != "e5m2":
+        if type_a != "e5m2" or (type_b != "e5m2" and type_b != "bf16"):
             pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
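
Read as a predicate, the updated gate says: on HIP, only an e5m2 LHS is exercised, and the RHS may now be bf16 in addition to e5m2. A small self-contained sketch that enumerates the outcomes; the candidate type lists are illustrative assumptions, not the test's actual parametrization:

# Sketch of the HIP skip predicate from the hunk above; the type lists
# below are assumptions for illustration, not the real test parameters.
def hip_should_skip(type_a: str, type_b: str) -> bool:
    # The exact updated condition: e5m2 LHS only; RHS may be e5m2 or bf16.
    return type_a != "e5m2" or (type_b != "e5m2" and type_b != "bf16")

for type_a in ("e4m3", "e5m2"):
    for type_b in ("e4m3", "e5m2", "bf16"):
        verdict = "skip" if hip_should_skip(type_a, type_b) else "run"
        print(f"scaled_dot({type_a}, {type_b}) on HIP: {verdict}")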
third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp (23 changes: 14 additions & 9 deletions)
@@ -511,7 +511,8 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
           aElemType == ScaleDotElemType::E5M2))
       return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS");
     if (!(bElemType == ScaleDotElemType::E4M3 ||
-          bElemType == ScaleDotElemType::E5M2))
+          bElemType == ScaleDotElemType::E5M2 ||
+          bElemType == ScaleDotElemType::BF16))
       return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS");
 
     MLIRContext *ctx = dotOp.getContext();
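
The pattern-level check mirrors the test gate but is wider on the LHS: the "non-mxfp8 LHS" message and the visible E5M2 continuation suggest the truncated LHS condition also admits E4M3 (an assumption, since that line sits above the hunk). A Python rendering of the two acceptance checks after this change, with hypothetical string stand-ins for the ScaleDotElemType enum:

# Hypothetical rendering of the two C++ checks above; enum values are
# modeled as strings, and the E4M3 LHS case is inferred, not shown.
def accepts(a_elem: str, b_elem: str) -> bool:
    lhs_ok = a_elem in ("E4M3", "E5M2")          # else "NYI: non-mxfp8 LHS"
    rhs_ok = b_elem in ("E4M3", "E5M2", "BF16")  # RHS may now also be BF16
    return lhs_ok and rhs_ok

print(accepts("E5M2", "BF16"))  # True after this commit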
@@ -568,21 +569,25 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
 
     auto toMMABf16 = [&](TensorValue v, int idx,
                          ScaleDotElemType type) -> TensorValue {
-      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3);
+      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3 ||
+             type == ScaleDotElemType::BF16);
+
       auto vType = v.getType();
-      auto newVEnc = DotOperandEncodingAttr::get(
+      auto newVEncoding = DotOperandEncodingAttr::get(
           ctx, idx, newRetType.getEncoding(), kWdith);
-      auto newVType = RankedTensorType::get(vType.getShape(),
-                                            vType.getElementType(), newVEnc);
+      auto newVType = RankedTensorType::get(
+          vType.getShape(), vType.getElementType(), newVEncoding);
       v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+      if (type == ScaleDotElemType::BF16)
+        return v;
 
-      auto vTypeFp8 =
-          RankedTensorType::get(vType.getShape(), enumToType(type), newVEnc);
+      auto vTypeFp8 = RankedTensorType::get(vType.getShape(), enumToType(type),
+                                            newVEncoding);
       v = cast<TensorValue>(
           rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
 
-      auto vTypeBf16 = RankedTensorType::get(vType.getShape(),
-                                             rewriter.getBF16Type(), newVEnc);
+      auto vTypeBf16 = RankedTensorType::get(
+          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
       return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
     };
     a = toMMABf16(a, 0, aElemType);
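
The rewritten toMMABf16 lambda does three things: give the operand the MFMA dot-operand layout (ttg::ConvertLayoutOp); then, for fp8 inputs only, bitcast the storage to the fp8 element type (BitcastOp) and upcast it to bf16 (FpToFpOp). A bf16 operand now returns early after the layout conversion. A minimal numeric sketch of the element-type leg, assuming numpy plus the ml_dtypes package as stand-ins (the layout step has no numeric effect, so it is omitted):

# Minimal sketch under stated assumptions: numpy + ml_dtypes stand in
# for the Triton types; layout conversion is metadata-only and skipped.
import numpy as np
import ml_dtypes

def to_mma_bf16(v: np.ndarray) -> np.ndarray:
    if v.dtype == ml_dtypes.bfloat16:
        return v  # new fast path: bf16 operands pass through unchanged
    assert v.dtype in (ml_dtypes.float8_e5m2, ml_dtypes.float8_e4m3fn)
    return v.astype(ml_dtypes.bfloat16)  # fp8 -> bf16 upcast (FpToFpOp)

a = np.arange(4, dtype=np.float32).astype(ml_dtypes.float8_e5m2)
b = np.arange(4, dtype=np.float32).astype(ml_dtypes.bfloat16)
print(to_mma_bf16(a).dtype, to_mma_bf16(b).dtype)  # bfloat16 bfloat16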
