From 305e1ba168e47b45b7866090e554c1e3dbde7099 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 24 Oct 2024 00:18:58 +0000 Subject: [PATCH 1/5] [AMD] Add initial support for scaled_dot(mxfp8, fp8) --- .../Conversion/TritonGPUToLLVM/Utility.h | 8 + lib/Conversion/TritonGPUToLLVM/Utility.cpp | 18 +++ lib/Dialect/TritonGPU/IR/Ops.cpp | 7 - .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 + .../PatternTritonGPUOpToLLVM.h | 5 + .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 3 + .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 130 +++++++++++++++ .../AccelerateAMDMatmul.cpp | 148 +++++++++++++++++- .../UpcastMXFPToLLVM.cpp | 21 +-- 9 files changed, 313 insertions(+), 28 deletions(-) create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 29b8865c03ae..a04c2470ca9f 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -391,6 +391,14 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } + +// ----------------------------------------------------------------------- +// MXFP utilities +// ----------------------------------------------------------------------- + +// Split bf16x2 into 2 bf16, scale each of them, and pack them back +Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v, + Value scale); } // namespace LLVM /* ------------------------------------ */ diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e857dd36f6cb..610c5427ce5c 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -856,5 +856,23 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } +Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + // Split bf16x2 into 2 bf16, scale each of them, and pack them back + // TODO Is it true that the bfloats are always packed as bf16x2? 
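+  // Note on the conversion below: the 8-bit MXFP scale is an E8M0 value, i.e.
+  // a biased exponent (bias 127) with no sign or mantissa bits. bf16 has an
+  // 8-bit exponent field starting at bit 7, so shifting the scale byte left by
+  // 7 yields the bf16 value 2^(scale - 127) directly. The exception is 0xff,
+  // which E8M0 reserves for NaN; 0xff << 7 is bf16 +inf rather than NaN, which
+  // is why the NaN case is handled with an explicit select below.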
+ auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); + auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); + auto scaleIsNan = icmp_eq(scale, i8_val(0xff)); + auto scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + auto scaledBf16_0 = fmul(bf16_0, scaleBf16); + auto scaledBf16_1 = fmul(bf16_1, scaleBf16); + auto i16_0 = bitcast(scaledBf16_0, i16_ty); + auto i16_1 = bitcast(scaledBf16_1, i16_ty); + auto packed = or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); + // Account for NaN in the scale as per the mxfp specification + auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); + return packed_nan; +}; + } // namespace LLVM } // namespace mlir diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index e61fe096e10b..20cee1bf1f1b 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -62,13 +62,6 @@ LogicalResult UpcastMXFPOp::verify() { } auto blockedScale = cast(layoutScale); - // Necessary to keep all of the scales of a given block of values in the same - // warp - auto threadsPerWarp = blockedScale.getThreadsPerWarp(); - if (threadsPerWarp != ArrayRef({16, 2})) { - return emitOpError("Expected threads per warp to be {16, 2}"); - } - return success(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index b6a514f450cc..abd86dc03301 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM OptimizeLDSUtility.cpp SPMDOpToLLVM.cpp SchedInstructions.cpp + UpcastMXFPToLLVM.cpp DEPENDS TritonAMDGPUConversionPassIncGen diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 764f31a610e1..1fdf3bdaa1cd 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); +void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfo &targetInfo, + PatternBenefit benefit); + } // namespace mlir::triton::AMD #endif diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index aa71c92666f7..d227bb6c6a4b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -207,6 +207,8 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter, patterns, AMDBenefit); + mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns, + targetInfo, AMDBenefit); // TODO(thomas): this should probably be done in a separate step to not // interfere with our own lowering of arith ops. 
Add arith/math's patterns @@ -223,6 +225,7 @@ struct ConvertTritonAMDGPUToLLVM mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, commonBenefit); mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp new file mode 100644 index 000000000000..422551d18781 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -0,0 +1,130 @@ +#include "PatternTritonGPUOpToLLVM.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::gpu; + +namespace { +class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { +private: + const TargetInfoBase &targetInfo; + +public: + UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { + } + + LogicalResult + matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto fpType = op.getFpType(); + if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2)) + return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases"); + + Location loc = op.getLoc(); + auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); + auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); + + // When we lower scaled dot op, we made sure to distribute K only on one + // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values + // along the K dimension. So in total each thread should read 32x main + // element values. Those fp16 values are packed in xVals to be 32-bit. + assert(xVals.size() == scaleVals.size() * 32 / 2); + + auto dotEncoding = + dyn_cast(op.getSrc().getType().getEncoding()); + auto mfmaEncoding = dyn_cast(dotEncoding.getParent()); + if (!mfmaEncoding) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); + if (dotEncoding.getOpIdx() == 1) + return rewriter.notifyMatchFailure(op, "NYI: dot RHS"); + + int mDim = mfmaEncoding.getMDim(); + if (mDim != 32 || mDim != 16) + return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); + + int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( + op->getParentOfType()); + Value warpSize = i32_val(numThreads); + Value tid = tid_val(); + Value warpId = udiv(tid, warpSize); + Value laneId = urem(tid, warpSize); + + // Given that MFMA layout for the A tensor arranges thread in a column-major + // manner, for the current tid, it's at row (tid / mDim). When we set up + // blocked layout for the A scale tensor, we made sure that it has a + // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values + // for the current thread starts at (tid / mDim * (64 / mDim)). 
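+    // E.g., for mfma32 on a 64-lane wavefront the A scale tensor uses
+    // threadsPerWarp = [32, 2] with the K dimension fastest, so the lanes
+    // holding the scales for tensor row r are lanes r * 2 and r * 2 + 1.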
+ Value offset = mul(udiv(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); + + if (mDim == 32) { + // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we + // tile, the same warp owns the whole K dim. Inside a warp, each thread + // only holds 4 consecutive elements along K--a 1x4 vector. We need to + // tile the warp 4 times to cover 32 values along K. So for a thread, the + // first 4 1x4 vectors it holds shares the first scale value at row (tid / + // mDim). the second 4 1x4 vectors shares the second scale value at row + // (tid / mDim); and so forth. + std::array scaleThreads = {offset, add(offset, i32_val(1))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + std::array si = { + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + }; + + for (int j = 0; j < 16; ++j) { + xVals[16 * i + j] = LLVM::mxfpScaleBf16x2( + rewriter, loc, xVals[16 * i + j], si[j / 8]); + } + } + } else { + assert(mDim == 16); + // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we + // need to tile the warp 2 times to cover 32 valeus. So for a thread, the + // first 2 1x4 vectors shares the first scale value at row (tid / mDim). + std::array scaleThreads = {offset, add(offset, i32_val(1)), + add(offset, i32_val(2)), + add(offset, i32_val(3))}; + + for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { + auto si = std::array{ + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]), + targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), + }; + + for (int j = 0; j < 16; ++j) { + xVals[16 * i + j] = LLVM::mxfpScaleBf16x2( + rewriter, loc, xVals[16 * i + j], si[j / 4]); + } + } + } + + Value result = + packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // anonymous namespace + +void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 6f93bfee99c7..3aa009c3639a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -160,6 +160,15 @@ FailureOr chooseMfmaInstruction(tt::DotOp dot, int mfmaVersion, aType.getShape().back(), mfmaVersion, nonKDim); } +FailureOr chooseMfmaInstruction(tt::DotScaledOp dot, int mfmaVersion, + int nonKDim) { + // For scaled dot, we handle it with bf16 emulation for now. 
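+  // Both operands get upcast to bf16 before the dot (see ScaledBlockedToMFMA
+  // below), so the intrinsic is chosen as if this were a plain bf16 x bf16 dot
+  // with the same result shape, using the LHS's K dimension.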
+ Type bf16Type = Builder(dot.getContext()).getBF16Type(); + return chooseMfmaInstruction( + dot.getC().getType(), /*aElemType=*/bf16Type, /*bElemType=*/bf16Type, + dot.getLhs().getType().getShape().back(), mfmaVersion, nonKDim); +} + using OperandTypesVector = SmallVector; OperandTypesVector selectMatrixCoreOperandTypes(tt::DotOp dot, @@ -469,6 +478,141 @@ class BlockedToMFMA : public OpRewritePattern { } }; +class ScaledBlockedToMFMA final : public OpRewritePattern { + int mfmaVersion; + int nonKDim; + int kPack; + +public: + ScaledBlockedToMFMA(MLIRContext *context, int mfmaVersion, int nonKDim, + int kPack, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), + nonKDim(nonKDim), kPack(kPack) {} + + LogicalResult matchAndRewrite(triton::DotScaledOp dotOp, + PatternRewriter &rewriter) const override { + using TensorValue = TypedValue; + + RankedTensorType oldRetType = dotOp.getType(); + if (!isa_and_nonnull(oldRetType.getEncoding())) + return rewriter.notifyMatchFailure( + dotOp, "expected blocked encoding result tensor"); + + if (dotOp.getRhsScale()) + return rewriter.notifyMatchFailure(dotOp, "NYI: RHS scale"); + + TensorValue a = dotOp.getLhs(); + TensorValue b = dotOp.getRhs(); + TensorValue aScale = dotOp.getLhsScale(); + F8F6F4Type aElemType = dotOp.getLhsType(); + F8F6F4Type bElemType = dotOp.getRhsType(); + + if (!(aElemType == F8F6F4Type::E4M3 || aElemType == F8F6F4Type::E5M2)) + return rewriter.notifyMatchFailure(dotOp, "NYI: non-mxfp8 LHS"); + if (!(bElemType == F8F6F4Type::E4M3 || bElemType == F8F6F4Type::E5M2)) + return rewriter.notifyMatchFailure(dotOp, "NYI: non-fp8 RHS"); + + MLIRContext *ctx = dotOp.getContext(); + auto moduleOp = dotOp->getParentOfType(); + + ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding()); + int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp); + int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp); + + // Choose a suitable MFMA instruction for this scaled dot op. + FailureOr mfmaInstr = + chooseMfmaInstruction(dotOp, mfmaVersion, nonKDim); + if (failed(mfmaInstr)) + return rewriter.notifyMatchFailure(dotOp, "cannot choose mfma intrinsic"); + + unsigned mDim = mfmaInstr.value().getMDim(); + unsigned nDim = mfmaInstr.value().getNDim(); + unsigned kDim = mfmaInstr.value().getKDim(); + unsigned kBase = mfmaInstr.value().getKBase(); + unsigned kWdith = kBase *= kPack; + + // For A tensor, 32 consecutive elements along K dim share the same scale. + // We'd like to keep the scale values together with the base values in the + // same warp to avoid cross-warp data exchange. It means we want warpsPerCTA + // = 1 along the N dimension. + SmallVector warpsPerCTA(oldRetType.getRank(), 1); + warpsPerCTA.front() = numWarps; + + // Always use transposed mfma layout. This enables larger vectorization + // for global store instructions. + auto mfmaEnc = ttg::AMDMfmaEncodingAttr::get( + ctx, /*versionMajor=*/mfmaVersion, /*versionMinor=*/0, warpsPerCTA, + /*instrShape=*/mDim, nDim, /*isTransposed=*/true, ctaLayout); + + auto newRetType = RankedTensorType::get( + oldRetType.getShape(), oldRetType.getElementType(), mfmaEnc); + + auto newAcc = rewriter.create( + dotOp.getC().getLoc(), newRetType, dotOp.getC()); + + // OCP mxfp8 requires implementations to follow OCP fp8 elements. We are + // doing software emulation using bf16 here, so we map to OCP fp8 f8E4M3FN + // and f8E5M2. 
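+    // Note: MLIR's Float8E4M3FN corresponds to the OCP E4M3 encoding (no
+    // infinities, a single NaN mantissa pattern per sign), and Float8E5M2
+    // matches OCP E5M2, which follows IEEE-style special values.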
+ auto enumToType = [&rewriter](F8F6F4Type type) { + switch (type) { + case F8F6F4Type::E4M3: + return rewriter.getFloat8E4M3FNType(); + case F8F6F4Type::E5M2: + return rewriter.getFloat8E5M2Type(); + default: + llvm_unreachable("unexpected fp type"); + } + }; + + auto toMMABf16 = [&](TensorValue v, int idx, + F8F6F4Type type) -> TensorValue { + assert(type == F8F6F4Type::E5M2 || type == F8F6F4Type::E4M3); + auto vType = v.getType(); + auto newVEnc = DotOperandEncodingAttr::get( + ctx, idx, newRetType.getEncoding(), kWdith); + auto newVType = RankedTensorType::get(vType.getShape(), + vType.getElementType(), newVEnc); + v = rewriter.create(v.getLoc(), newVType, v); + + auto vTypeFp8 = + RankedTensorType::get(vType.getShape(), enumToType(type), newVEnc); + v = cast( + rewriter.create(v.getLoc(), vTypeFp8, v).getResult()); + + auto vTypeBf16 = RankedTensorType::get(vType.getShape(), + rewriter.getBF16Type(), newVEnc); + return rewriter.create(v.getLoc(), vTypeBf16, v); + }; + a = toMMABf16(a, 0, aElemType); + b = toMMABf16(b, 1, bElemType); + + // We need to have "matching" encoding between the A tensor and A scale + // tensor to make sure the scale values needed is in the same warp. So we + // adopt the same CTA layout and warps per CTA. The warp dimensions needs to + // match along M dimension too. With in a warp, we have 64 threads. We let + // each thread read in one scale value. So we need a threadsPerWarp = mDim + // along M dimension. + SmallVector threadsPerWarp = {mDim, numThreads / mDim}; + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + ctx, {1, 1}, threadsPerWarp, warpsPerCTA, {1, 0}, ctaLayout); + + auto newScaleType = RankedTensorType::get(aScale.getType().getShape(), + aScale.getType().getElementType(), + newScaleEncoding); + aScale = rewriter.create(aScale.getLoc(), + newScaleType, aScale); + + auto scaledA = rewriter.create( + dotOp.getLoc(), a, aScale, dotOp.getLhsType()); + + auto newDot = + rewriter.create(dotOp.getLoc(), newRetType, scaledA, b, newAcc); + rewriter.replaceOpWithNewOp(dotOp, oldRetType, + newDot); + return success(); + } +}; + static Value promoteOperand(OpBuilder &builder, Location loc, Value operand, Type promotedType) { Type tensorPromotedType = cast(operand.getType()) @@ -690,8 +834,8 @@ class TritonAMDGPUAccelerateMatmulPass case ISAFamily::CDNA1: case ISAFamily::CDNA2: case ISAFamily::CDNA3: - patterns.add<::BlockedToMFMA>(context, getMfmaVersion(isaFamily), - matrixInstructionSize, kPack); + patterns.add<::BlockedToMFMA, ::ScaledBlockedToMFMA>( + context, getMfmaVersion(isaFamily), matrixInstructionSize, kPack); break; case ISAFamily::RDNA3: patterns.add<::BlockedToWMMA>(context, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 722bf56cd015..19bb3792a4e8 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -107,24 +107,6 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); } - auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { - // Split bf16x2 into 2 bf16, scale each of them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? 
- auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(s, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = - or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; - }; - // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + // 16, c + 17 @@ -142,7 +124,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { }; for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); + xVals[16 * i + j] = + LLVM::mxfpScaleBf16x2(rewriter, loc, xVals[16 * i + j], si[j / 4]); } } From f1349dba8e2cdff04f09a9772736c8dc965ade4c Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 25 Oct 2024 20:16:53 +0000 Subject: [PATCH 2/5] Enable certain tests --- python/test/unit/language/test_core.py | 20 +++++++++------- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 24 +++++++++++-------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 182445836e71..2aa80c00481d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3329,12 +3329,13 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid for type_a in ["e2m1", "e4m3", "e5m2"] for type_b in ["e4m3", "e5m2"]]) def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device): - if not is_cuda(): - pytest.skip("scaled_dot only supported on CUDA") - else: + if is_cuda(): cc = torch.cuda.get_device_capability() if cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + if is_hip(): + if type_a != "e5m2" or type_b != "e5m2": + pytest.skip(f"{type_a} * {type_b} not yet implemented for HIP") @triton.jit def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out, @@ -3503,12 +3504,13 @@ def make_finite(x, dtype): torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2) # make sure ld/st are vectorized - ptx = pgm.asm['ptx'] - if (max(M, N) * K) // (num_warps * 32) >= 4: - assert 'ld.global.v4' in ptx - if M * N // (num_warps * 32) >= 4: - assert 'st.global.v4' in ptx - assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) + if is_cuda(): + ptx = pgm.asm['ptx'] + if (max(M, N) * K) // (num_warps * 32) >= 4: + assert 'ld.global.v4' in ptx + if M * N // (num_warps * 32) >= 4: + assert 'st.global.v4' in ptx + assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx) @pytest.mark.interpreter diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 422551d18781..ea0d0b39ca7b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -10,6 +10,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" #include using 
namespace mlir; @@ -37,6 +38,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Location loc = op.getLoc(); auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter); auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter); + LDBG("x: " << xVals.size() << " x " << xVals.front().getType()); + LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType()); // When we lower scaled dot op, we made sure to distribute K only on one // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values @@ -45,15 +48,16 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { assert(xVals.size() == scaleVals.size() * 32 / 2); auto dotEncoding = - dyn_cast(op.getSrc().getType().getEncoding()); + cast(op.getSrc().getType().getEncoding()); + if (dotEncoding.getOpIdx() == 1) + return rewriter.notifyMatchFailure(op, "NYI: dot RHS"); auto mfmaEncoding = dyn_cast(dotEncoding.getParent()); if (!mfmaEncoding) return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand"); - if (dotEncoding.getOpIdx() == 1) - return rewriter.notifyMatchFailure(op, "NYI: dot RHS"); + LDBG("mfma: " << mfmaEncoding); int mDim = mfmaEncoding.getMDim(); - if (mDim != 32 || mDim != 16) + if (mDim != 32 && mDim != 16) return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics"); int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp( @@ -64,20 +68,20 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { Value laneId = urem(tid, warpSize); // Given that MFMA layout for the A tensor arranges thread in a column-major - // manner, for the current tid, it's at row (tid / mDim). When we set up + // manner, for the current tid, it's at row (tid % mDim). When we set up // blocked layout for the A scale tensor, we made sure that it has a // threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values - // for the current thread starts at (tid / mDim * (64 / mDim)). - Value offset = mul(udiv(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); + // for the current thread starts at ((tid % mDim) * (64 / mDim)). + Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim)); if (mDim == 32) { // One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we // tile, the same warp owns the whole K dim. Inside a warp, each thread // only holds 4 consecutive elements along K--a 1x4 vector. We need to // tile the warp 4 times to cover 32 values along K. So for a thread, the - // first 4 1x4 vectors it holds shares the first scale value at row (tid / + // first 4 1x4 vectors it holds shares the first scale value at row (tid % // mDim). the second 4 1x4 vectors shares the second scale value at row - // (tid / mDim); and so forth. + // (tid % mDim); and so forth. std::array scaleThreads = {offset, add(offset, i32_val(1))}; for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) { @@ -95,7 +99,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { assert(mDim == 16); // One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we // need to tile the warp 2 times to cover 32 valeus. So for a thread, the - // first 2 1x4 vectors shares the first scale value at row (tid / mDim). + // first 2 1x4 vectors shares the first scale value at row (tid % mDim). 
std::array scaleThreads = {offset, add(offset, i32_val(1)), add(offset, i32_val(2)), add(offset, i32_val(3))}; From 27d403bfc6a5ce6bddd53591805c9fca65608f9f Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 25 Oct 2024 22:31:01 +0000 Subject: [PATCH 3/5] Drop packing logic given we process standalone elements --- .../Conversion/TritonGPUToLLVM/Utility.h | 8 ------ lib/Conversion/TritonGPUToLLVM/Utility.cpp | 18 ------------- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 27 +++++++++++++------ .../UpcastMXFPToLLVM.cpp | 21 +++++++++++++-- 4 files changed, 38 insertions(+), 36 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index a04c2470ca9f..29b8865c03ae 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -391,14 +391,6 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } - -// ----------------------------------------------------------------------- -// MXFP utilities -// ----------------------------------------------------------------------- - -// Split bf16x2 into 2 bf16, scale each of them, and pack them back -Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v, - Value scale); } // namespace LLVM /* ------------------------------------ */ diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 610c5427ce5c..e857dd36f6cb 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -856,23 +856,5 @@ SmallVector getWrappedMultiDimOffset( return multiDimOffsetWrapped; } -Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v, - Value scale) { - // Split bf16x2 into 2 bf16, scale each of them, and pack them back - // TODO Is it true that the bfloats are always packed as bf16x2? - auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); - auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); - auto scaleIsNan = icmp_eq(scale, i8_val(0xff)); - auto scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); - auto scaledBf16_0 = fmul(bf16_0, scaleBf16); - auto scaledBf16_1 = fmul(bf16_1, scaleBf16); - auto i16_0 = bitcast(scaledBf16_0, i16_ty); - auto i16_1 = bitcast(scaledBf16_1, i16_ty); - auto packed = or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); - // Account for NaN in the scale as per the mxfp specification - auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); - return packed_nan; -}; - } // namespace LLVM } // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index ea0d0b39ca7b..4e43e71fc434 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -18,6 +18,17 @@ using namespace mlir::triton; using namespace mlir::triton::gpu; namespace { + +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); + Value scaledBf16 = fmul(v, scaleBf16); + // Account for NaN in the scale as per the mxfp specification. 
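+  // The explicit select is needed because the E8M0 NaN encoding (0xff) becomes
+  // 0x7f80 after the shift above, which is bf16 +inf rather than NaN, so the
+  // fmul alone would not produce a NaN result.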
+ return select(scaleIsNan, nanBf16, scaledBf16); +}; + class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { private: const TargetInfoBase &targetInfo; @@ -44,8 +55,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { // When we lower scaled dot op, we made sure to distribute K only on one // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values // along the K dimension. So in total each thread should read 32x main - // element values. Those fp16 values are packed in xVals to be 32-bit. - assert(xVals.size() == scaleVals.size() * 32 / 2); + // element values. + assert(xVals.size() == scaleVals.size() * 32); auto dotEncoding = cast(op.getSrc().getType().getEncoding()); @@ -90,9 +101,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]), }; - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = LLVM::mxfpScaleBf16x2( - rewriter, loc, xVals[16 * i + j], si[j / 8]); + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); } } } else { @@ -112,9 +123,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]), }; - for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = LLVM::mxfpScaleBf16x2( - rewriter, loc, xVals[16 * i + j], si[j / 4]); + for (int j = 0; j < 32; ++j) { + int index = 32 * i + j; + xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); } } } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 19bb3792a4e8..722bf56cd015 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -107,6 +107,24 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { xVals = unpackFP4Elements(loc, rewriter, xVals, laneId); } + auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value { + // Split bf16x2 into 2 bf16, scale each of them, and pack them back + // TODO Is it true that the bfloats are always packed as bf16x2? 
+ auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty); + auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty); + auto scaleIsNan = icmp_eq(s, i8_val(0xff)); + auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty); + auto scaledBf16_0 = fmul(bf16_0, scaleBf16); + auto scaledBf16_1 = fmul(bf16_1, scaleBf16); + auto i16_0 = bitcast(scaledBf16_0, i16_ty); + auto i16_1 = bitcast(scaledBf16_1, i16_ty); + auto packed = + or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16))); + // Account for NaN in the scale as per the mxfp specification + auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed); + return packed_nan; + }; + // Each thread owns elements of 4 mxfp vectors so we need 4 scales // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c + // 16, c + 17 @@ -124,8 +142,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { }; for (int j = 0; j < 16; ++j) { - xVals[16 * i + j] = - LLVM::mxfpScaleBf16x2(rewriter, loc, xVals[16 * i + j], si[j / 4]); + xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]); } } From 88d273924428c84377a53eda646acc19a3fd1428 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 25 Oct 2024 22:40:53 +0000 Subject: [PATCH 4/5] Adjust UpcastMXFPOp op verification --- lib/Dialect/TritonGPU/IR/Ops.cpp | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 20cee1bf1f1b..b5499ea5a5ab 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -52,15 +52,26 @@ LogicalResult UpcastMXFPOp::verify() { "all dimensions except the last must match between operands"); } - auto layoutX = xTy.getEncoding(); - if (!layoutX || !isa(layoutX)) { + auto dotEncoding = + dyn_cast_or_null(xTy.getEncoding()); + if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - auto layoutScale = scaleTy.getEncoding(); - if (!layoutScale || !isa(layoutScale)) { + + auto blockedScale = + dyn_cast_or_null(scaleTy.getEncoding()); + if (!blockedScale) { return emitOpError("Expected a BlockOperandEncoding for scales"); } - auto blockedScale = cast(layoutScale); + + if (isa(dotEncoding.getParent())) { + // Necessary to keep all of the scales of a given block of values in the + // same warp + auto threadsPerWarp = blockedScale.getThreadsPerWarp(); + if (threadsPerWarp != ArrayRef({16, 2})) { + return emitOpError("Expected threads per warp to be {16, 2}"); + } + } return success(); } From 06e19822a3b89c25a77d0a2be45d928629456bcc Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sat, 26 Oct 2024 00:04:32 +0000 Subject: [PATCH 5/5] Test more configurations --- python/test/unit/language/test_core.py | 21 ++++++++++++------- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 3 ++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 2aa80c00481d..1cebd2577969 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3322,20 +3322,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx -@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps", - [(M, N, K, col_a, col_b, type_a, type_b, 4) +@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack", + [(M, N, K, col_a, col_b, type_a, type_b, 4, 
mma, kpack)
                          for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                          for col_a, col_b in itertools.product([True, False], repeat=2)
                          for type_a in ["e2m1", "e4m3", "e5m2"]
-                         for type_b in ["e4m3", "e5m2"]])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
+                         for type_b in ["e4m3", "e5m2"]
+                         for mma in ([32, 16] if is_hip() else [16])
+                         for kpack in ([1, 2] if is_hip() else [1])])
+def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
     if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
     if is_hip():
         if type_a != "e5m2" or type_b != "e5m2":
-            pytest.skip(f"{type_a} * {type_b} not yet implemented for HIP")
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if mma == 16 and K == 64:
+            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
@@ -3494,9 +3498,12 @@ def make_finite(x, dtype):
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)
 
+    kernel_kwargs = {"num_warps": num_warps}
+    if is_hip():
+        kernel_kwargs["kpack"] = kpack
+        kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
-                                  num_warps=num_warps)
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)
 
     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
 
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
index 4e43e71fc434..289ceb61a51b 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -56,7 +56,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     // warp. MXFP spec mandates 1 scale value for every 32 onsecutive values
     // along the K dimension. So in total each thread should read 32x main
     // element values.
-    assert(xVals.size() == scaleVals.size() * 32);
+    if (xVals.size() != scaleVals.size() * 32)
+      return rewriter.notifyMatchFailure(op, "unsupported problem size");
 
     auto dotEncoding =
         cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());