From 27d403bfc6a5ce6bddd53591805c9fca65608f9f Mon Sep 17 00:00:00 2001
From: Lei Zhang
Date: Fri, 25 Oct 2024 22:31:01 +0000
Subject: [PATCH] Drop packing logic given we process standalone elements

---
 .../Conversion/TritonGPUToLLVM/Utility.h      |  8 ------
 lib/Conversion/TritonGPUToLLVM/Utility.cpp    | 18 -------------
 .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp   | 27 +++++++++++++------
 .../UpcastMXFPToLLVM.cpp                      | 21 +++++++++++++--
 4 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
index a04c2470ca9f..29b8865c03ae 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -391,14 +391,6 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
   Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
   return base;
 }
-
-// -----------------------------------------------------------------------
-// MXFP utilities
-// -----------------------------------------------------------------------
-
-// Split bf16x2 into 2 bf16, scale each of them, and pack them back
-Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v,
-                      Value scale);
 } // namespace LLVM
 
 /* ------------------------------------ */
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
index 610c5427ce5c..e857dd36f6cb 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -856,23 +856,5 @@ SmallVector<Value> getWrappedMultiDimOffset(
   return multiDimOffsetWrapped;
 }
 
-Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v,
-                      Value scale) {
-  // Split bf16x2 into 2 bf16, scale each of them, and pack them back
-  // TODO Is it true that the bfloats are always packed as bf16x2?
-  auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty);
-  auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty);
-  auto scaleIsNan = icmp_eq(scale, i8_val(0xff));
-  auto scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
-  auto scaledBf16_0 = fmul(bf16_0, scaleBf16);
-  auto scaledBf16_1 = fmul(bf16_1, scaleBf16);
-  auto i16_0 = bitcast(scaledBf16_0, i16_ty);
-  auto i16_1 = bitcast(scaledBf16_1, i16_ty);
-  auto packed = or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16)));
-  // Account for NaN in the scale as per the mxfp specification
-  auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed);
-  return packed_nan;
-};
-
 } // namespace LLVM
 } // namespace mlir
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
index ea0d0b39ca7b..4e43e71fc434 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -18,6 +18,17 @@ using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
 namespace {
+
+Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
+                    Value scale) {
+  Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
+  Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
+  Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
+  Value scaledBf16 = fmul(v, scaleBf16);
+  // Account for NaN in the scale as per the mxfp specification.
+  return select(scaleIsNan, nanBf16, scaledBf16);
+};
+
 class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
 private:
   const TargetInfoBase &targetInfo;
@@ -44,8 +55,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
     // When we lower scaled dot op, we made sure to distribute K only on one
    // warp. MXFP spec mandates 1 scale value for every 32 consecutive values
     // along the K dimension. So in total each thread should read 32x main
-    // element values. Those fp16 values are packed in xVals to be 32-bit.
-    assert(xVals.size() == scaleVals.size() * 32 / 2);
+    // element values.
+    assert(xVals.size() == scaleVals.size() * 32);
 
     auto dotEncoding =
         cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
@@ -90,9 +101,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
           targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
       };
 
-      for (int j = 0; j < 16; ++j) {
-        xVals[16 * i + j] = LLVM::mxfpScaleBf16x2(
-            rewriter, loc, xVals[16 * i + j], si[j / 8]);
+      for (int j = 0; j < 32; ++j) {
+        int index = 32 * i + j;
+        xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]);
       }
     }
   } else {
@@ -112,9 +123,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
           targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]),
      };
 
-      for (int j = 0; j < 16; ++j) {
-        xVals[16 * i + j] = LLVM::mxfpScaleBf16x2(
-            rewriter, loc, xVals[16 * i + j], si[j / 4]);
+      for (int j = 0; j < 32; ++j) {
+        int index = 32 * i + j;
+        xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]);
       }
     }
   }
diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
index 19bb3792a4e8..722bf56cd015 100644
--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
+++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -107,6 +107,24 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
       xVals = unpackFP4Elements(loc, rewriter, xVals, laneId);
     }
 
+    auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value {
+      // Split bf16x2 into 2 bf16, scale each of them, and pack them back
+      // TODO Is it true that the bfloats are always packed as bf16x2?
+      auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty);
+      auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty);
+      auto scaleIsNan = icmp_eq(s, i8_val(0xff));
+      auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty);
+      auto scaledBf16_0 = fmul(bf16_0, scaleBf16);
+      auto scaledBf16_1 = fmul(bf16_1, scaleBf16);
+      auto i16_0 = bitcast(scaledBf16_0, i16_ty);
+      auto i16_1 = bitcast(scaledBf16_1, i16_ty);
+      auto packed =
+          or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16)));
+      // Account for NaN in the scale as per the mxfp specification
+      auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed);
+      return packed_nan;
+    };
+
     // Each thread owns elements of 4 mxfp vectors so we need 4 scales
     // Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c +
     // 16, c + 17
@@ -124,8 +142,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
       };
 
       for (int j = 0; j < 16; ++j) {
-        xVals[16 * i + j] =
-            LLVM::mxfpScaleBf16x2(rewriter, loc, xVals[16 * i + j], si[j / 4]);
+        xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]);
       }
     }
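
For reference, the arithmetic behind the new per-element mxfpScaleBf16 helper can be sketched in plain C++ as below. This is an illustration only, not part of the patch: it assumes the 8-bit scale is an e8m0 exponent (a factor of 2^(byte - 127), with 0xFF encoding NaN), and the names e8m0ScaleToFloat and mxfpScaleOneElement are hypothetical.

// Standalone sketch of the scaling math above: decode an e8m0 scale byte and
// apply it to one element, mirroring the arithmetic the lowering emits as IR.
#include <cmath>
#include <cstdint>
#include <cstring>

// Shifting the scale byte left by 7 places it in the bf16 exponent field,
// which is what shl(zext(i16_ty, scale), i16_val(7)) does in the lowering;
// 0xFF is treated as NaN per the MXFP spec.
float e8m0ScaleToFloat(uint8_t scale) {
  if (scale == 0xFF)
    return std::nanf("");
  uint16_t bf16Bits = static_cast<uint16_t>(scale) << 7;    // exponent-only bf16
  uint32_t f32Bits = static_cast<uint32_t>(bf16Bits) << 16; // widen bf16 to f32
  float f;
  std::memcpy(&f, &f32Bits, sizeof(f));
  return f;
}

// Per-element scaling: multiply by the decoded scale, forcing NaN when the
// scale byte itself is NaN (the select(scaleIsNan, ...) in mxfpScaleBf16).
float mxfpScaleOneElement(float value, uint8_t scale) {
  float s = e8m0ScaleToFloat(scale);
  return std::isnan(s) ? std::nanf("") : value * s;
}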