diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp
index e61fe096e10b..b5499ea5a5ab 100644
--- a/lib/Dialect/TritonGPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -52,21 +52,25 @@ LogicalResult UpcastMXFPOp::verify() {
         "all dimensions except the last must match between operands");
   }
-  auto layoutX = xTy.getEncoding();
-  if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+  auto dotEncoding =
+      dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+  if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
-  auto layoutScale = scaleTy.getEncoding();
-  if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+
+  auto blockedScale =
+      dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+  if (!blockedScale) {
     return emitOpError("Expected a BlockOperandEncoding for scales");
   }
-  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
-  // Necessary to keep all of the scales of a given block of values in the same
-  // warp
-  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
-  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-    return emitOpError("Expected threads per warp to be {16, 2}");
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
+    }
   }
 
   return success();
diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 182445836e71..1cebd2577969 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3322,19 +3322,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
 
 
-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps",
-                         [(M, N, K, col_a, col_b, type_a, type_b, 4)
+@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack",
+                         [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack)
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
                           for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2"]])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
-    if not is_cuda():
-        pytest.skip("scaled_dot only supported on CUDA")
-    else:
+                          for type_b in ["e4m3", "e5m2"]
+                          for mma in ([32, 16] if is_hip() else [16])
+                          for kpack in ([1, 2] if is_hip() else [1])])
+def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
+    if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
+    if is_hip():
+        if type_a != "e5m2" or type_b != "e5m2":
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if mma == 16 and K == 64:
+            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
@@ -3493,9 +3498,12 @@ def make_finite(x, dtype):
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)
 
+    kernel_kwargs = {"num_warps": num_warps}
+    if is_hip():
+        kernel_kwargs["kpack"] = kpack
+        kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
-                                  num_warps=num_warps)
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)
 
     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
 
@@ -3503,12 +3511,13 @@ def make_finite(x, dtype):
     torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)
 
     # make sure ld/st are vectorized
-    ptx = pgm.asm['ptx']
-    if (max(M, N) * K) // (num_warps * 32) >= 4:
-        assert 'ld.global.v4' in ptx
-    if M * N // (num_warps * 32) >= 4:
-        assert 'st.global.v4' in ptx
-    assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
+    if is_cuda():
+        ptx = pgm.asm['ptx']
+        if (max(M, N) * K) // (num_warps * 32) >= 4:
+            assert 'ld.global.v4' in ptx
+        if M * N // (num_warps * 32) >= 4:
+            assert 'st.global.v4' in ptx
+        assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
 
 
 @pytest.mark.interpreter
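For reference, a minimal sketch of the semantics this test exercises (this is not the `dot_scale_ref` helper used above, and the names are illustrative): every 32 consecutive K elements of the LHS share one e8m0 scale byte, i.e. a multiplier of 2**(scale - 127), and the matmul is done after upcasting both operands. Assumes a PyTorch build with `torch.float8_e5m2`.

```python
import torch

def scaled_dot_sketch(x_fp8, x_scale, y_fp8):
    # x_fp8: (M, K) torch.float8_e5m2, x_scale: (M, K // 32) torch.uint8 (e8m0),
    # y_fp8: (K, N) torch.float8_e5m2.
    x = x_fp8.to(torch.float32)
    y = y_fp8.to(torch.float32)
    # e8m0: value = 2**(exponent - 127); 0xff encodes NaN.
    mult = torch.pow(2.0, x_scale.to(torch.float32) - 127)
    mult = torch.where(x_scale == 0xff, torch.full_like(mult, float("nan")), mult)
    # One scale byte covers 32 consecutive elements along K.
    x = x * mult.repeat_interleave(32, dim=1)
    return (x @ y).to(torch.bfloat16)
```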
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
index b6a514f450cc..abd86dc03301 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM
   OptimizeLDSUtility.cpp
   SPMDOpToLLVM.cpp
   SchedInstructions.cpp
+  UpcastMXFPToLLVM.cpp
 
   DEPENDS
   TritonAMDGPUConversionPassIncGen
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h
index 764f31a610e1..1fdf3bdaa1cd 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h
@@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         PatternBenefit benefit);
 
+void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      const TargetInfo &targetInfo,
+                                      PatternBenefit benefit);
+
 } // namespace mlir::triton::AMD
 
 #endif
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
index aa71c92666f7..d227bb6c6a4b 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -207,6 +207,8 @@ struct ConvertTritonAMDGPUToLLVM
     mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter,
                                                           patterns, AMDBenefit);
+    mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
+                                                        targetInfo, AMDBenefit);
 
     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns
@@ -223,6 +225,7 @@ struct ConvertTritonAMDGPUToLLVM
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, commonBenefit);
     mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
+
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
       return signalPassFailure();
     }
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
new file mode 100644
index 000000000000..289ceb61a51b
--- /dev/null
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -0,0 +1,146 @@
+#include "PatternTritonGPUOpToLLVM.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/Attributes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+
+#include <array>
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::gpu;
+
+namespace {
+
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
+                    Value scale) {
+  Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
+  Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
+  Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
+  Value scaledBf16 = fmul(v, scaleBf16);
+  // Account for NaN in the scale as per the mxfp specification.
+  return select(scaleIsNan, nanBf16, scaledBf16);
+}
+
+class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
+private:
+  const TargetInfoBase &targetInfo;
+
+public:
+  UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter,
+                      const TargetInfoBase &targetInfo, PatternBenefit benefit)
+      : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
+  }
+
+  LogicalResult
+  matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto fpType = op.getFpType();
+    if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2))
+      return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases");
+
+    Location loc = op.getLoc();
+    auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+    auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter);
+    LDBG("x: " << xVals.size() << " x " << xVals.front().getType());
+    LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType());
+
+    // When we lower the scaled dot op, we make sure to distribute K over only
+    // one warp. The MXFP spec mandates one scale value for every 32
+    // consecutive values along the K dimension. So in total each thread should
+    // read 32x as many main element values as scale values.
+    if (xVals.size() != scaleVals.size() * 32)
+      return rewriter.notifyMatchFailure(op, "unsupported problem size");
+
+    auto dotEncoding =
+        cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
+    if (dotEncoding.getOpIdx() == 1)
+      return rewriter.notifyMatchFailure(op, "NYI: dot RHS");
+    auto mfmaEncoding = dyn_cast<AMDMfmaEncodingAttr>(dotEncoding.getParent());
+    if (!mfmaEncoding)
+      return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand");
+    LDBG("mfma: " << mfmaEncoding);
+
+    int mDim = mfmaEncoding.getMDim();
+    if (mDim != 32 && mDim != 16)
+      return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics");
+
+    int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
+        op->getParentOfType<ModuleOp>());
+    Value warpSize = i32_val(numThreads);
+    Value tid = tid_val();
+    Value warpId = udiv(tid, warpSize);
+    Value laneId = urem(tid, warpSize);
+
+    // Given that the MFMA layout for the A tensor arranges threads in a
+    // column-major manner, the current tid sits at row (tid % mDim). When we
+    // set up the blocked layout for the A scale tensor, we made sure that it
+    // has threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale
+    // values for the current thread start at ((tid % mDim) * (64 / mDim)).
+    Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim));
+
+    if (mDim == 32) {
+      // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we
+      // tile, the same warp owns the whole K dim. Inside a warp, each thread
+      // only holds 4 consecutive elements along K--a 1x4 vector. We need to
+      // tile the warp 4 times to cover 32 values along K. So for a thread, the
+      // first 4 1x4 vectors it holds share the first scale value at row
+      // (tid % mDim); the second 4 1x4 vectors share the second scale value at
+      // row (tid % mDim); and so forth.
+      std::array<Value, 2> scaleThreads = {offset, add(offset, i32_val(1))};
+
+      for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
+        std::array<Value, 2> si = {
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
+        };
+
+        for (int j = 0; j < 32; ++j) {
+          int index = 32 * i + j;
+          xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]);
+        }
+      }
+    } else {
+      assert(mDim == 16);
+      // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we
+      // need to tile the warp 2 times to cover 32 values. So for a thread, the
+      // first 2 1x4 vectors share the first scale value at row (tid % mDim),
+      // and so forth.
+      std::array<Value, 4> scaleThreads = {offset, add(offset, i32_val(1)),
+                                           add(offset, i32_val(2)),
+                                           add(offset, i32_val(3))};
+
+      for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
+        auto si = std::array{
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]),
+            targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]),
+        };
+
+        for (int j = 0; j < 32; ++j) {
+          int index = 32 * i + j;
+          xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]);
+        }
+      }
+    }
+
+    Value result =
+        packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // anonymous namespace
+
+void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    const TargetInfo &targetInfo, PatternBenefit benefit) {
+  patterns.add<UpcastMXFPOpPattern>(typeConverter, targetInfo, benefit);
+}
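The bit trick in `mxfpScaleBf16` above is compact enough to miss: an e8m0 scale byte `e` becomes a bf16 multiplier simply by shifting it into the bf16 exponent field. A small Python sketch of the same mapping (illustrative only, not part of the patch):

```python
import math
import struct

def e8m0_to_bf16(e: int) -> float:
    # 0xff is the e8m0 NaN encoding; the pattern handles it with a select.
    if e == 0xff:
        return math.nan
    # bf16 pattern: sign=0, exponent=e, mantissa=0, i.e. 2**(e - 127) for
    # normal exponents (e == 0 flushes to +0.0 under this emulation).
    bf16_bits = e << 7
    f32_bits = bf16_bits << 16  # widen the bf16 pattern to an f32 pattern
    return struct.unpack(">f", f32_bits.to_bytes(4, "big"))[0]

assert e8m0_to_bf16(127) == 1.0  # 2**0
assert e8m0_to_bf16(130) == 8.0  # 2**3
assert e8m0_to_bf16(126) == 0.5  # 2**-1
```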
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
index 6f93bfee99c7..3aa009c3639a 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp
@@ -160,6 +160,15 @@ FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion,
                                aType.getShape().back(), mfmaVersion, nonKDim);
 }
 
+FailureOr<MfmaInsn> chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion,
+                                          int nonKDim) {
+  // For scaled dot, we handle it with bf16 emulation for now.
+  Type bf16Type = Builder(dot.getContext()).getBF16Type();
+  return chooseMfmaInstruction(
+      dot.getC().getType(), /*aElemType=*/bf16Type, /*bElemType=*/bf16Type,
+      dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim);
+}
+
 using OperandTypesVector = SmallVector<Type, 4>;
 
 OperandTypesVector selectMatrixCoreOperandTypes(tt::DotOp dot,
@@ -469,6 +478,141 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
   }
 };
 
+class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
+  int mfmaVersion;
+  int nonKDim;
+  int kPack;
+
+public:
+  ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim,
+                      int kPack, PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion),
+        nonKDim(nonKDim), kPack(kPack) {}
+
+  LogicalResult matchAndRewrite(triton::DotScaledOp dotOp,
+                                PatternRewriter &rewriter) const override {
+    using TensorValue = TypedValue<RankedTensorType>;
+
+    RankedTensorType oldRetType = dotOp.getType();
+    if (!isa_and_nonnull<ttg::BlockedEncodingAttr>(oldRetType.getEncoding()))
+      return rewriter.notifyMatchFailure(
+          dotOp, "expected blocked encoding result tensor");
+
+    if (dotOp.getRhsScale())
+      return rewriter.notifyMatchFailure(dotOp, "NYI: RHS scale");
+
+    TensorValue a = dotOp.getLhs();
+    TensorValue b = dotOp.getRhs();
+    TensorValue aScale = dotOp.getLhsScale();
+    F8F6F4Type aElemType = dotOp.getLhsType();
+    F8F6F4Type bElemType = dotOp.getRhsType();
+
+    if (!(aElemType == F8F6F4Type::E4M3 || aElemType == F8F6F4Type::E5M2))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS");
+    if (!(bElemType == F8F6F4Type::E4M3 || bElemType == F8F6F4Type::E5M2))
+      return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS");
+
+    MLIRContext *ctx = dotOp.getContext();
+    auto moduleOp = dotOp->getParentOfType<ModuleOp>();
+
+    ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding());
+    int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
+    int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
+
+    // Choose a suitable MFMA instruction for this scaled dot op.
+    FailureOr<MfmaInsn> mfmaInstr =
+        chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim);
+    if (failed(mfmaInstr))
+      return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic");
+
+    unsigned mDim = mfmaInstr.value().getMDim();
+    unsigned nDim = mfmaInstr.value().getNDim();
+    unsigned kDim = mfmaInstr.value().getKDim();
+    unsigned kBase = mfmaInstr.value().getKBase();
+    unsigned kWidth = kBase * kPack;
+
+    // For the A tensor, 32 consecutive elements along the K dim share the same
+    // scale. We'd like to keep the scale values together with the base values
+    // in the same warp to avoid cross-warp data exchange. That means we want
+    // warpsPerCTA = 1 along the N dimension.
+    SmallVector<unsigned> warpsPerCTA(oldRetType.getRank(), 1);
+    warpsPerCTA.front() = numWarps;
+
+    // Always use the transposed mfma layout. This enables larger vectorization
+    // for global store instructions.
+    auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
+        ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerCTA,
+        /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout);
+
+    auto newRetType = RankedTensorType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc);
+
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(
+        dotOp.getC().getLoc(), newRetType, dotOp.getC());
+
+    // OCP mxfp8 requires implementations to follow the OCP fp8 element
+    // formats. We are doing software emulation using bf16 here, so we map to
+    // the OCP fp8 types f8E4M3FN and f8E5M2.
+    auto enumToType = [&rewriter](F8F6F4Type type) {
+      switch (type) {
+      case F8F6F4Type::E4M3:
+        return rewriter.getFloat8E4M3FNType();
+      case F8F6F4Type::E5M2:
+        return rewriter.getFloat8E5M2Type();
+      default:
+        llvm_unreachable("unexpected fp type");
+      }
+    };
+
+    auto toMMABf16 = [&](TensorValue v, int idx,
+                         F8F6F4Type type) -> TensorValue {
+      assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3);
+      auto vType = v.getType();
+      auto newVEnc = DotOperandEncodingAttr::get(
+          ctx, idx, newRetType.getEncoding(), kWidth);
+      auto newVType = RankedTensorType::get(vType.getShape(),
+                                            vType.getElementType(), newVEnc);
+      v = rewriter.create<ttg::ConvertLayoutOp>(v.getLoc(), newVType, v);
+
+      auto vTypeFp8 =
+          RankedTensorType::get(vType.getShape(), enumToType(type), newVEnc);
+      v = cast<TensorValue>(
+          rewriter.create<tt::BitcastOp>(v.getLoc(), vTypeFp8, v).getResult());
+
+      auto vTypeBf16 = RankedTensorType::get(vType.getShape(),
+                                             rewriter.getBF16Type(), newVEnc);
+      return rewriter.create<tt::FpToFpOp>(v.getLoc(), vTypeBf16, v);
+    };
+    a = toMMABf16(a, 0, aElemType);
+    b = toMMABf16(b, 1, bElemType);
+
+    // We need to have a "matching" encoding between the A tensor and the A
+    // scale tensor to make sure the scale values needed are in the same warp.
+    // So we adopt the same CTA layout and warps per CTA. The warp dimensions
+    // need to match along the M dimension too. Within a warp, we have 64
+    // threads. We let each thread read in one scale value, so we need
+    // threadsPerWarp = mDim along the M dimension.
+    SmallVector<unsigned> threadsPerWarp = {mDim, numThreads / mDim};
+    auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get(
+        ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, ctaLayout);
+
+    auto newScaleType = RankedTensorType::get(aScale.getType().getShape(),
+                                              aScale.getType().getElementType(),
+                                              newScaleEncoding);
+    aScale = rewriter.create<ttg::ConvertLayoutOp>(aScale.getLoc(),
+                                                   newScaleType, aScale);
+
+    auto scaledA = rewriter.create<ttg::UpcastMXFPOp>(
+        dotOp.getLoc(), a, aScale, dotOp.getLhsType());
+
+    auto newDot = rewriter.create<tt::DotOp>(dotOp.getLoc(), newRetType,
+                                             scaledA, b, newAcc);
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(dotOp, oldRetType,
+                                                      newDot);
+    return success();
+  }
+};
+
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
                             Type promotedType) {
   Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
@@ -690,8 +834,8 @@ class TritonAMDGPUAccelerateMatmulPass
     case ISAFamily::CDNA1:
     case ISAFamily::CDNA2:
     case ISAFamily::CDNA3:
-      patterns.add<::BlockedToMFMA>(context, getMfmaVersion(isaFamily),
-                                    matrixInstructionSize, kPack);
+      patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>(
+          context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack);
       break;
     case ISAFamily::RDNA3:
       patterns.add<::BlockedToWMMA>(context,
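A small sketch of the lane mapping that ties the scale layout chosen in `ScaledBlockedToMFMA` to the shuffles in `UpcastMXFPToLLVM.cpp` (illustrative, assuming a 64-lane wave): with threadsPerWarp = [mDim, 64/mDim] and order [1, 0], the scale for row r and K-block k32 lives in lane r * (64/mDim) + k32, which is exactly the `offset + j` the lowering reads via `shuffleIdx`.

```python
def scale_lane(row: int, k32: int, m_dim: int, wave_size: int = 64) -> int:
    # threadsPerWarp = [m_dim, wave_size // m_dim] with order [1, 0]:
    # the K dimension is the fastest-varying one across lanes.
    k_lanes = wave_size // m_dim  # 2 for mfma32, 4 for mfma16
    assert 0 <= k32 < k_lanes
    return (row % m_dim) * k_lanes + k32

# mfma32: lanes 10 and 11 hold the scales for row 5.
assert [scale_lane(5, j, 32) for j in range(2)] == [10, 11]
# mfma16: lanes 20..23 hold the scales for row 5.
assert [scale_lane(5, j, 16) for j in range(4)] == [20, 21, 22, 23]
```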