Drop packing logic given we process standalone elements
antiagainst committed Oct 25, 2024
1 parent f1349db commit 27d403b
Showing 4 changed files with 38 additions and 36 deletions.
8 changes: 0 additions & 8 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -391,14 +391,6 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
return base;
}

- // -----------------------------------------------------------------------
- // MXFP utilities
- // -----------------------------------------------------------------------
-
- // Split bf16x2 into 2 bf16, scale each of them, and pack them back
- Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v,
- Value scale);
} // namespace LLVM

/* ------------------------------------ */
18 changes: 0 additions & 18 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -856,23 +856,5 @@ SmallVector<Value> getWrappedMultiDimOffset(
return multiDimOffsetWrapped;
}

- Value mxfpScaleBf16x2(RewriterBase &rewriter, Location loc, Value v,
- Value scale) {
- // Split bf16x2 into 2 bf16, scale each of them, and pack them back
- // TODO Is it true that the bfloats are always packed as bf16x2?
- auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty);
- auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty);
- auto scaleIsNan = icmp_eq(scale, i8_val(0xff));
- auto scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
- auto scaledBf16_0 = fmul(bf16_0, scaleBf16);
- auto scaledBf16_1 = fmul(bf16_1, scaleBf16);
- auto i16_0 = bitcast(scaledBf16_0, i16_ty);
- auto i16_1 = bitcast(scaledBf16_1, i16_ty);
- auto packed = or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16)));
- // Account for NaN in the scale as per the mxfp specification
- auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed);
- return packed_nan;
- };

} // namespace LLVM
} // namespace mlir
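
For readers less familiar with the bit manipulation involved, here is a rough host-side sketch of what the dropped mxfpScaleBf16x2 helper computed, written against plain integer bits instead of MLIR builder macros. The helper names, the use of std::memcpy for bit casts, and the truncating bf16 conversion are assumptions made purely for illustration; they are not part of this commit, and the rounding behaviour of the real lowering is not modelled.

#include <cstdint>
#include <cstring>

// Illustrative sketch (not code from this commit): split a packed bf16x2
// word into its two halves, scale each by the bf16 encoding of the E8M0
// scale byte, and pack the results back. A 0xFF scale maps both halves to
// the canonical bf16 NaN, per the mxfp specification.
inline uint16_t bf16Mul(uint16_t a, uint16_t b) {
  // bf16 is the upper half of fp32, so widen, multiply, and truncate back;
  // rounding differences versus the GPU lowering are ignored here.
  uint32_t ua = uint32_t(a) << 16, ub = uint32_t(b) << 16;
  float fa, fb;
  std::memcpy(&fa, &ua, sizeof(fa));
  std::memcpy(&fb, &ub, sizeof(fb));
  float fr = fa * fb;
  uint32_t ur;
  std::memcpy(&ur, &fr, sizeof(ur));
  return uint16_t(ur >> 16);
}

inline uint32_t scaleBf16x2Bits(uint32_t packed, uint8_t scale) {
  if (scale == 0xFF)
    return 0x7FFF7FFFu;                      // both lanes become bf16 NaN
  uint16_t scaleBits = uint16_t(scale) << 7; // E8M0 byte -> bf16 exponent field
  uint16_t lo = bf16Mul(uint16_t(packed & 0xFFFF), scaleBits);
  uint16_t hi = bf16Mul(uint16_t(packed >> 16), scaleBits);
  return uint32_t(lo) | (uint32_t(hi) << 16);
}

The same packed logic survives as a local lambda in the NVIDIA lowering below, while the AMD lowering switches to the standalone-element helper introduced next.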
27 changes: 19 additions & 8 deletions third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -18,6 +18,17 @@ using namespace mlir::triton;
using namespace mlir::triton::gpu;

namespace {

+ Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
+ Value scale) {
+ Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
+ Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
+ Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
+ Value scaledBf16 = fmul(v, scaleBf16);
+ // Account for NaN in the scale as per the mxfp specification.
+ return select(scaleIsNan, nanBf16, scaledBf16);
+ };

class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
const TargetInfoBase &targetInfo;
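
One note on the shift by 7 in the new mxfpScaleBf16 helper above: the E8M0 scale byte is an exponent-only encoding, so placing it directly into the bf16 exponent field turns it into a power-of-two multiplier. Below is a minimal host-side illustration of that reading, assuming bias-127 semantics and ignoring scale bytes that fall outside bf16's normal exponent range; this sketch and its function name are not part of the commit.

#include <cmath>
#include <cstdint>

// What the E8M0 scale byte denotes once moved into the bf16 exponent field:
// roughly 2^(scale - 127), with 0xFF reserved for NaN per the mxfp spec.
inline float e8m0ScaleValue(uint8_t scale) {
  if (scale == 0xFF)
    return NAN;
  return std::ldexp(1.0f, int(scale) - 127);
}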
@@ -44,8 +55,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
// When we lower scaled dot op, we made sure to distribute K only on one
// warp. MXFP spec mandates 1 scale value for every 32 consecutive values
// along the K dimension. So in total each thread should read 32x main
- // element values. Those fp16 values are packed in xVals to be 32-bit.
- assert(xVals.size() == scaleVals.size() * 32 / 2);
+ // element values.
+ assert(xVals.size() == scaleVals.size() * 32);

auto dotEncoding =
cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
@@ -90,9 +101,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
};

- for (int j = 0; j < 16; ++j) {
- xVals[16 * i + j] = LLVM::mxfpScaleBf16x2(
- rewriter, loc, xVals[16 * i + j], si[j / 8]);
+ for (int j = 0; j < 32; ++j) {
+ int index = 32 * i + j;
+ xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]);
}
}
} else {
@@ -112,9 +123,9 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]),
};

- for (int j = 0; j < 16; ++j) {
- xVals[16 * i + j] = LLVM::mxfpScaleBf16x2(
- rewriter, loc, xVals[16 * i + j], si[j / 4]);
+ for (int j = 0; j < 32; ++j) {
+ int index = 32 * i + j;
+ xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]);
}
}
}
21 changes: 19 additions & 2 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -107,6 +107,24 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
xVals = unpackFP4Elements(loc, rewriter, xVals, laneId);
}

+ auto scaleBf16x2 = [&loc, &rewriter](Value v, Value s) -> Value {
+ // Split bf16x2 into 2 bf16, scale each of them, and pack them back
+ // TODO Is it true that the bfloats are always packed as bf16x2?
+ auto bf16_0 = bitcast(trunc(i16_ty, v), bf16_ty);
+ auto bf16_1 = bitcast(trunc(i16_ty, lshr(v, i32_val(16))), bf16_ty);
+ auto scaleIsNan = icmp_eq(s, i8_val(0xff));
+ auto scaleBf16 = bitcast(shl(zext(i16_ty, s), i16_val(7)), bf16_ty);
+ auto scaledBf16_0 = fmul(bf16_0, scaleBf16);
+ auto scaledBf16_1 = fmul(bf16_1, scaleBf16);
+ auto i16_0 = bitcast(scaledBf16_0, i16_ty);
+ auto i16_1 = bitcast(scaledBf16_1, i16_ty);
+ auto packed =
+ or_(zext(i32_ty, i16_0), shl(zext(i32_ty, i16_1), i32_val(16)));
+ // Account for NaN in the scale as per the mxfp specification
+ auto packed_nan = select(scaleIsNan, i32_val(0x7fff7fff), packed);
+ return packed_nan;
+ };

// Each thread owns elements of 4 mxfp vectors so we need 4 scales
// Letting c = tid / 4 * 2, we need the elements from threads c, c + 1, c +
// 16, c + 17
@@ -124,8 +142,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
};

for (int j = 0; j < 16; ++j) {
- xVals[16 * i + j] =
- LLVM::mxfpScaleBf16x2(rewriter, loc, xVals[16 * i + j], si[j / 4]);
+ xVals[16 * i + j] = scaleBf16x2(xVals[16 * i + j], si[j / 4]);
}
}
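
To restate the layout comment above (c = tid / 4 * 2, then lanes c, c + 1, c + 16, c + 17) in runnable form, here is a tiny helper sketch; the function name is hypothetical and not part of this commit.

#include <array>

// Lanes holding the four scale values a given thread needs, per the layout
// comment: c = tid / 4 * 2, then lanes c, c + 1, c + 16, c + 17.
inline std::array<int, 4> scaleLanesForThread(int tid) {
  int c = tid / 4 * 2;
  return {c, c + 1, c + 16, c + 17};
}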
