[AMD] Add initial support for scaled_dot(mxfp8, fp8) (triton-lang#4994)
This commit adds initial support for scaled_dot with
mxfp8 LHS and fp8 RHS. It supports both mfma32
and mfma16 intrinsic variants.

Right now we are missing software emulation for
the `Float8E4M3FN` type, so this is only enabled
for `Float8E5M2`.
antiagainst authored Oct 28, 2024
1 parent 8cdba56 commit 3549db8
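
For context, here is a minimal, hypothetical sketch of a kernel that would hit the path this commit enables: an mxfp8 LHS (fp8 values plus one e8m0 scale per 32 K elements) multiplied against a plain fp8 RHS with no scale. It loosely mirrors the dot_scale_kernel used in test_core.py below; the tl.dot_scaled argument order and the pointer layout are assumptions to check against the Triton version in use, not code from this commit.

# Hypothetical illustration only; not code from this commit.
import triton
import triton.language as tl


@triton.jit
def mxfp8_fp8_dot_kernel(a_ptr, a_scale_ptr, b_ptr, out_ptr,
                         M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                         type_a: tl.constexpr, type_b: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major A (M x K), its per-32-element scales (M x K // 32),
    # and row-major B (K x N).
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    a_scale = tl.load(a_scale_ptr + offs_m[:, None] * (K // 32) +
                      tl.arange(0, K // 32)[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # e.g. type_a = "e5m2", type_b = "e5m2"; the fp8 RHS carries no scale.
    c = tl.dot_scaled(a, a_scale, type_a, b, None, type_b)
    tl.store(out_ptr + offs_m[:, None] * N + offs_n[None, :], c)

On the AMD path, the scaled LHS is upcast to bf16 by the new UpcastMXFPToLLVM pattern added in this commit, and a regular mfma dot consumes the result.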
Showing 7 changed files with 339 additions and 27 deletions.
24 changes: 14 additions & 10 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -52,21 +52,25 @@ LogicalResult UpcastMXFPOp::verify() {
         "all dimensions except the last must match between operands");
   }
 
-  auto layoutX = xTy.getEncoding();
-  if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+  auto dotEncoding =
+      dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+  if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
-  auto layoutScale = scaleTy.getEncoding();
-  if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+
+  auto blockedScale =
+      dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+  if (!blockedScale) {
     return emitOpError("Expected a BlockOperandEncoding for scales");
   }
-  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);
 
-  // Necessary to keep all of the scales of a given block of values in the same
-  // warp
-  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
-  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-    return emitOpError("Expected threads per warp to be {16, 2}");
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
+    }
   }
 
   return success();
39 changes: 24 additions & 15 deletions python/test/unit/language/test_core.py
@@ -3322,19 +3322,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx
 
 
-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps",
-                         [(M, N, K, col_a, col_b, type_a, type_b, 4)
+@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack",
+                         [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack)
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
                           for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2"]])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
-    if not is_cuda():
-        pytest.skip("scaled_dot only supported on CUDA")
-    else:
+                          for type_b in ["e4m3", "e5m2"]
+                          for mma in ([32, 16] if is_hip() else [16])
+                          for kpack in ([1, 2] if is_hip() else [1])])
+def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
+    if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
+    if is_hip():
+        if type_a != "e5m2" or type_b != "e5m2":
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if mma == 16 and K == 64:
+            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
 
     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
@@ -3493,22 +3498,26 @@ def make_finite(x, dtype):
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)
 
+    kernel_kwargs = {"num_warps": num_warps}
+    if is_hip():
+        kernel_kwargs["kpack"] = kpack
+        kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
-                                  num_warps=num_warps)
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)
 
     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)
 
     # generous rtol as we are sampling the whole range of floats
     torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)
 
     # make sure ld/st are vectorized
-    ptx = pgm.asm['ptx']
-    if (max(M, N) * K) // (num_warps * 32) >= 4:
-        assert 'ld.global.v4' in ptx
-    if M * N // (num_warps * 32) >= 4:
-        assert 'st.global.v4' in ptx
-    assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
+    if is_cuda():
+        ptx = pgm.asm['ptx']
+        if (max(M, N) * K) // (num_warps * 32) >= 4:
+            assert 'ld.global.v4' in ptx
+        if M * N // (num_warps * 32) >= 4:
+            assert 'st.global.v4' in ptx
+        assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
 
 
 @pytest.mark.interpreter
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM
   OptimizeLDSUtility.cpp
   SPMDOpToLLVM.cpp
   SchedInstructions.cpp
+  UpcastMXFPToLLVM.cpp
 
   DEPENDS
   TritonAMDGPUConversionPassIncGen
@@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         PatternBenefit benefit);
 
+void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      const TargetInfo &targetInfo,
+                                      PatternBenefit benefit);
+
 } // namespace mlir::triton::AMD
 
 #endif
3 changes: 3 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -207,6 +207,8 @@ struct ConvertTritonAMDGPUToLLVM
 
     mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter,
                                                           patterns, AMDBenefit);
+    mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
+                                                        targetInfo, AMDBenefit);
 
     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns
@@ -223,6 +225,7 @@ struct ConvertTritonAMDGPUToLLVM
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, commonBenefit);
     mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
+
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
       return signalPassFailure();
     }
146 changes: 146 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -0,0 +1,146 @@
#include "PatternTritonGPUOpToLLVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <array>

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

namespace {

Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
Value scale) {
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
Value scaledBf16 = fmul(v, scaleBf16);
// Account for NaN in the scale as per the mxfp specification.
return select(scaleIsNan, nanBf16, scaledBf16);
};

class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
const TargetInfoBase &targetInfo;

public:
UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo, PatternBenefit benefit)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto fpType = op.getFpType();
if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2))
return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases");

Location loc = op.getLoc();
auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter);
LDBG("x: " << xVals.size() << " x " << xVals.front().getType());
LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType());

// When we lower the scaled dot op, we made sure to distribute K only on one
// warp. The MXFP spec mandates 1 scale value for every 32 consecutive values
// along the K dimension. So in total each thread should read 32x as many
// main element values as scale values.
if (xVals.size() != scaleVals.size() * 32)
return rewriter.notifyMatchFailure(op, "unsupported problem size");

auto dotEncoding =
cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
if (dotEncoding.getOpIdx() == 1)
return rewriter.notifyMatchFailure(op, "NYI: dot RHS");
auto mfmaEncoding = dyn_cast<AMDMfmaEncodingAttr>(dotEncoding.getParent());
if (!mfmaEncoding)
return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand");
LDBG("mfma: " << mfmaEncoding);

int mDim = mfmaEncoding.getMDim();
if (mDim != 32 && mDim != 16)
return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics");

int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
Value warpSize = i32_val(numThreads);
Value tid = tid_val();
Value warpId = udiv(tid, warpSize);
Value laneId = urem(tid, warpSize);

// Given that the MFMA layout for the A tensor arranges threads in a
// column-major manner, the current tid sits at row (tid % mDim). When we set
// up the blocked layout for the A scale tensor, we made sure that it has
// threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values
// for the current thread start at ((tid % mDim) * (64 / mDim)).
Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim));

if (mDim == 32) {
// One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we
// tile, the same warp owns the whole K dim. Inside a warp, each thread
// only holds 4 consecutive elements along K--a 1x4 vector. We need to
// tile the warp 4 times to cover 32 values along K. So for a thread, the
// first four 1x4 vectors it holds share the first scale value at row
// (tid % mDim); the second four 1x4 vectors share the second scale value
// at the same row; and so forth.
std::array<Value, 2> scaleThreads = {offset, add(offset, i32_val(1))};

for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
std::array<Value, 2> si = {
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
};

for (int j = 0; j < 32; ++j) {
int index = 32 * i + j;
xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]);
}
}
} else {
assert(mDim == 16);
// One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we
// need to tile the warp 2 times to cover 32 values along K. So for a
// thread, the first two 1x4 vectors share the first scale value at row
// (tid % mDim).
std::array<Value, 4> scaleThreads = {offset, add(offset, i32_val(1)),
add(offset, i32_val(2)),
add(offset, i32_val(3))};

for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
auto si = std::array<Value, 4>{
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]),
};

for (int j = 0; j < 32; ++j) {
int index = 32 * i + j;
xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]);
}
}
}

Value result =
packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
rewriter.replaceOp(op, result);
return success();
}
};
} // anonymous namespace

void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<UpcastMXFPOpPattern>(typeConverter, targetInfo, benefit);
}
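
A side note on the scale handling above (not part of the commit): E8M0 scales are pure powers of two, so mxfpScaleBf16 can materialize a scale as bf16 simply by shifting the 8-bit value into the bf16 exponent field, with 0xFF reserved for NaN. Below is a small host-side Python check of that bit trick, plus worked numbers for the lane-offset formula in the comments, assuming a 64-lane wavefront.

# Standalone reference check; mirrors the bit manipulation in mxfpScaleBf16
# but is not part of the commit.
import math
import struct


def e8m0_to_float(scale: int) -> float:
    """Decode an 8-bit E8M0 scale (a biased power-of-two exponent)."""
    if scale == 0xFF:            # NaN encoding per the MXFP spec
        return math.nan
    bf16_bits = scale << 7       # place the scale in the bf16 exponent field
    # bf16 is the top half of fp32, so widen with 16 zero mantissa bits.
    return struct.unpack(">f", struct.pack(">I", bf16_bits << 16))[0]


assert e8m0_to_float(127) == 1.0     # bias 127 encodes 2**0
assert e8m0_to_float(128) == 2.0     # 2**1
assert e8m0_to_float(126) == 0.5     # 2**-1
assert math.isnan(e8m0_to_float(0xFF))

# Worked numbers for the lane-offset formula above, assuming a 64-lane wave:
# offset = (lane % mDim) * (64 // mDim) is the first lane holding the scales
# for this lane's row of the A tile.
for m_dim in (32, 16):
    lane = 37
    print(m_dim, (lane % m_dim) * (64 // m_dim))   # -> 32: 10, 16: 20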