[AMD] Enable scaled_dot(-, bf16) (triton-lang#5029)
antiagainst authored Nov 1, 2024
1 parent ee5876c commit f062540
Showing 2 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/test/unit/language/test_core.py
@@ -3337,7 +3337,7 @@ def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
-        if type_a != "e5m2" or type_b != "e5m2":
+        if type_a != "e5m2" or (type_b != "e5m2" and type_b != "bf16"):
             pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
         if mma == 16 and K == 64:
             pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
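For context, a minimal sketch of the operand combination the relaxed skip above now exercises on HIP: an fp8 (e5m2) LHS carrying e8m0 block scales against an unscaled bf16 RHS. The kernel name, pointer arithmetic, and block sizes are illustrative assumptions rather than the upstream test kernel; only the tl.dot_scaled call is meant to mirror its shape (note the None scale for the bf16 operand).

import triton
import triton.language as tl


@triton.jit
def scaled_dot_e5m2_bf16(a_ptr, a_scale_ptr, b_ptr, c_ptr,
                         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                         BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # LHS: BLOCK_M x BLOCK_K fp8e5m2 values stored as uint8 bytes.
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    # One e8m0 scale byte per 32 K-contiguous LHS elements.
    offs_s = tl.arange(0, BLOCK_K // 32)
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (BLOCK_K // 32) + offs_s[None, :])
    # RHS: BLOCK_K x BLOCK_N plain bf16 values, no scale.
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    # scaled_dot(e5m2, bf16): only the fp8 operand carries a scale.
    c = tl.dot_scaled(a, a_scale, "e5m2", b, None, "bf16")
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)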
25 changes: 15 additions & 10 deletions third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -511,8 +511,9 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
           aElemType == ScaleDotElemType::E5M2))
       return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS");
     if (!(bElemType == ScaleDotElemType::E4M3 ||
-          bElemType == ScaleDotElemType::E5M2))
-      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS");
+          bElemType == ScaleDotElemType::E5M2 ||
+          bElemType == ScaleDotElemType::BF16))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8/bf16 RHS");
 
     MLIRContext *ctx = dotOp.getContext();
     auto moduleOp = dotOp->getParentOfType<ModuleOp>();
@@ -568,21 +569,25 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
 
     auto toMMABf16 = [&](TensorValue v, int idx,
                          ScaleDotElemType type) -> TensorValue {
-      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3);
+      assert(type == ScaleDotElemType::E5M2 || type == ScaleDotElemType::E4M3 ||
+             type == ScaleDotElemType::BF16);
+
       auto vType = v.getType();
-      auto newVEnc = DotOperandEncodingAttr::get(
+      auto newVEncoding = DotOperandEncodingAttr::get(
           ctx, idx, newRetType.getEncoding(), kWdith);
-      auto newVType = RankedTensorType::get(vType.getShape(),
-                                            vType.getElementType(), newVEnc);
+      auto newVType = RankedTensorType::get(
+          vType.getShape(), vType.getElementType(), newVEncoding);
       v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+      if (type == ScaleDotElemType::BF16)
+        return v;
 
-      auto vTypeFp8 =
-          RankedTensorType::get(vType.getShape(), enumToType(type), newVEnc);
+      auto vTypeFp8 = RankedTensorType::get(vType.getShape(), enumToType(type),
+                                            newVEncoding);
       v = cast<TensorValue>(
           rewriter.create<BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
 
-      auto vTypeBf16 = RankedTensorType::get(vType.getShape(),
-                                             rewriter.getBF16Type(), newVEnc);
+      auto vTypeBf16 = RankedTensorType::get(
+          vType.getShape(), rewriter.getBF16Type(), newVEncoding);
       return rewriter.create<FpToFpOp>(v.getLoc(), vTypeBf16, v);
     };
     a = toMMABf16(a, 0, aElemType);
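The early return for BF16 above means a bf16 operand only gets its layout converted and then feeds the MFMA directly, while an fp8 operand is still bitcast and upcast to bf16. Below is a hedged host-side reference for the resulting semantics, assuming the usual mxfp convention of one e8m0 scale byte (decoded as 2^(value - 127)) per 32 K-contiguous fp8 elements and a recent PyTorch with float8 dtypes; it emulates the math and is not the code path the pass emits.

import torch


def reference_scaled_dot(a_e5m2: torch.Tensor,       # (M, K) torch.float8_e5m2
                         a_scale_e8m0: torch.Tensor,  # (M, K // 32) torch.uint8
                         b_bf16: torch.Tensor         # (K, N) torch.bfloat16
                         ) -> torch.Tensor:
    # Upcast the fp8 operand to bf16, mirroring the BitcastOp + FpToFpOp pair above.
    a = a_e5m2.to(torch.bfloat16)
    # Decode e8m0 block scales and broadcast each byte over its 32-element K group.
    scale = torch.exp2(a_scale_e8m0.to(torch.float32) - 127)
    a = a.to(torch.float32) * scale.repeat_interleave(32, dim=1)
    # The bf16 operand participates unscaled and unconverted.
    return a @ b_bf16.to(torch.float32)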
