[AMD] Add initial support for scaled_dot(mxfp8, fp8) #4994

Merged 5 commits on Oct 28, 2024
24 changes: 14 additions & 10 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -52,21 +52,25 @@ LogicalResult UpcastMXFPOp::verify() {
         "all dimensions except the last must match between operands");
   }

-  auto layoutX = xTy.getEncoding();
-  if (!layoutX || !isa<DotOperandEncodingAttr>(layoutX)) {
+  auto dotEncoding =
+      dyn_cast_or_null<DotOperandEncodingAttr>(xTy.getEncoding());
+  if (!dotEncoding) {
     return emitOpError("Expected a DotOperandEncodingAttr for values");
   }
-  auto layoutScale = scaleTy.getEncoding();
-  if (!layoutScale || !isa<BlockedEncodingAttr>(layoutScale)) {
+
+  auto blockedScale =
+      dyn_cast_or_null<BlockedEncodingAttr>(scaleTy.getEncoding());
+  if (!blockedScale) {
     return emitOpError("Expected a BlockOperandEncoding for scales");
   }
-  auto blockedScale = cast<BlockedEncodingAttr>(layoutScale);

-  // Necessary to keep all of the scales of a given block of values in the same
-  // warp
-  auto threadsPerWarp = blockedScale.getThreadsPerWarp();
-  if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
-    return emitOpError("Expected threads per warp to be {16, 2}");
+  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
+    // Necessary to keep all of the scales of a given block of values in the
+    // same warp
+    auto threadsPerWarp = blockedScale.getThreadsPerWarp();
+    if (threadsPerWarp != ArrayRef<unsigned>({16, 2})) {
+      return emitOpError("Expected threads per warp to be {16, 2}");
+    }
   }

   return success();
39 changes: 24 additions & 15 deletions python/test/unit/language/test_core.py
@@ -3322,19 +3322,24 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
     assert 'wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3' in ptx


-@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps",
-                         [(M, N, K, col_a, col_b, type_a, type_b, 4)
+@pytest.mark.parametrize("M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack",
+                         [(M, N, K, col_a, col_b, type_a, type_b, 4, mma, kpack)
                           for M, N, K in itertools.product([32, 64, 128], [32, 64, 128], [64, 128])
                           for col_a, col_b in itertools.product([True, False], repeat=2)
                           for type_a in ["e2m1", "e4m3", "e5m2"]
-                          for type_b in ["e4m3", "e5m2"]])
-def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, device):
-    if not is_cuda():
-        pytest.skip("scaled_dot only supported on CUDA")
-    else:
+                          for type_b in ["e4m3", "e5m2"]
+                          for mma in ([32, 16] if is_hip() else [16])
+                          for kpack in ([1, 2] if is_hip() else [1])])
+def test_scaled_dot(M, N, K, col_a, col_b, type_a, type_b, num_warps, mma, kpack, device):
+    if is_cuda():
         cc = torch.cuda.get_device_capability()
         if cc < (8, 9):
             pytest.skip("float8e4nv not supported on CUDA < 8.9")
+    if is_hip():
+        if type_a != "e5m2" or type_b != "e5m2":
+            pytest.skip(f"scaled_dot({type_a}, {type_b}) not yet implemented for HIP")
+        if mma == 16 and K == 64:
+            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")

     @triton.jit
     def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, out,
@@ -3493,22 +3498,26 @@ def make_finite(x, dtype):
     x = make_finite(x, type_a)
     y = make_finite(y, type_b)

+    kernel_kwargs = {"num_warps": num_warps}
+    if is_hip():
+        kernel_kwargs["kpack"] = kpack
+        kernel_kwargs["matrix_instr_nonkdim"] = mma
     z = x.new_empty((M, N), dtype=torch.bfloat16)
-    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b,
-                                  num_warps=num_warps)
+    pgm = dot_scale_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), z, M, N, K, type_a, type_b, **kernel_kwargs)

     z_ref = dot_scale_ref(x, scale_x, y, type_a, type_b)

     # generous rtol as we are sampling the whole range of floats
     torch.testing.assert_close(z, z_ref, atol=1e-5, rtol=1e-2)

     # make sure ld/st are vectorized
-    ptx = pgm.asm['ptx']
-    if (max(M, N) * K) // (num_warps * 32) >= 4:
-        assert 'ld.global.v4' in ptx
-    if M * N // (num_warps * 32) >= 4:
-        assert 'st.global.v4' in ptx
-    assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)
+    if is_cuda():
+        ptx = pgm.asm['ptx']
+        if (max(M, N) * K) // (num_warps * 32) >= 4:
+            assert 'ld.global.v4' in ptx
+        if M * N // (num_warps * 32) >= 4:
+            assert 'st.global.v4' in ptx
+        assert re.search(r'mma.sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.bf16.bf16', ptx)


 @pytest.mark.interpreter
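The test compares the kernel output against dot_scale_ref, which is defined elsewhere in this file and not shown in the diff. As a rough, hypothetical reference (not the helper the test actually uses), a scaled_dot(mxfp8, fp8) result can be emulated in PyTorch by decoding each E8M0 scale byte as 2**(byte - 127), broadcasting it over its group of 32 consecutive K elements of the A operand, and doing a plain bf16 matmul; a row-major (M, K // 32) uint8 scale layout is assumed here:

import torch

def scaled_dot_emulated(x_fp8, scale_x, y_fp8):
    # x_fp8: (M, K) float8 tensor, scale_x: (M, K // 32) uint8 E8M0 scales,
    # y_fp8: (K, N) float8 tensor. E8M0 NaN (0xFF) handling is omitted.
    x = x_fp8.to(torch.bfloat16)
    y = y_fp8.to(torch.bfloat16)
    # E8M0 stores only a biased exponent: value = 2 ** (byte - 127).
    scales = torch.exp2(scale_x.to(torch.float32) - 127).to(torch.bfloat16)
    # Each scale covers 32 consecutive values along K of the A operand.
    x = x * scales.repeat_interleave(32, dim=1)
    return x @ y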
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -20,6 +20,7 @@ add_triton_library(TritonAMDGPUToLLVM
   OptimizeLDSUtility.cpp
   SPMDOpToLLVM.cpp
   SchedInstructions.cpp
+  UpcastMXFPToLLVM.cpp

   DEPENDS
   TritonAMDGPUConversionPassIncGen
5 changes: 5 additions & 0 deletions
@@ -34,6 +34,11 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                         RewritePatternSet &patterns,
                                         PatternBenefit benefit);

+void populateUpcastMXFPToLLVMPatterns(LLVMTypeConverter &typeConverter,
+                                      RewritePatternSet &patterns,
+                                      const TargetInfo &targetInfo,
+                                      PatternBenefit benefit);
+
 } // namespace mlir::triton::AMD

 #endif
3 changes: 3 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -207,6 +207,8 @@ struct ConvertTritonAMDGPUToLLVM

     mlir::triton::AMD::populateTritonAMDGPUToLLVMPatterns(typeConverter,
                                                           patterns, AMDBenefit);
+    mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(typeConverter, patterns,
+                                                        targetInfo, AMDBenefit);

     // TODO(thomas): this should probably be done in a separate step to not
     // interfere with our own lowering of arith ops. Add arith/math's patterns
@@ -223,6 +225,7 @@ struct ConvertTritonAMDGPUToLLVM
     mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
                                                targetInfo, commonBenefit);
     mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
+
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {
       return signalPassFailure();
     }
146 changes: 146 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp
@@ -0,0 +1,146 @@
#include "PatternTritonGPUOpToLLVM.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <array>

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

namespace {

Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
Value scale) {
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
Value scaledBf16 = fmul(v, scaleBf16);
// Account for NaN in the scale as per the mxfp specification.
return select(scaleIsNan, nanBf16, scaledBf16);
};

class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
const TargetInfoBase &targetInfo;

public:
UpcastMXFPOpPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo, PatternBenefit benefit)
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

LogicalResult
matchAndRewrite(UpcastMXFPOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto fpType = op.getFpType();
if (!(fpType == F8F6F4Type::E4M3 || fpType == F8F6F4Type::E5M2))
return rewriter.notifyMatchFailure(op, "NYI: non-mxfp8 cases");

Location loc = op.getLoc();
auto xVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
auto scaleVals = unpackLLElements(loc, adaptor.getScale(), rewriter);
LDBG("x: " << xVals.size() << " x " << xVals.front().getType());
LDBG("scale: " << scaleVals.size() << " x " << scaleVals.front().getType());

// When lowering the scaled dot op, we made sure to distribute K onto only one
// warp. The MXFP spec mandates 1 scale value for every 32 consecutive values
// along the K dimension. So in total each thread should hold 32x as many main
// element values as scale values.
if (xVals.size() != scaleVals.size() * 32)
return rewriter.notifyMatchFailure(op, "unsupported problem size");

auto dotEncoding =
cast<DotOperandEncodingAttr>(op.getSrc().getType().getEncoding());
if (dotEncoding.getOpIdx() == 1)
return rewriter.notifyMatchFailure(op, "NYI: dot RHS");
auto mfmaEncoding = dyn_cast<AMDMfmaEncodingAttr>(dotEncoding.getParent());
if (!mfmaEncoding)
return rewriter.notifyMatchFailure(op, "NYI: non-mfma dot operand");
LDBG("mfma: " << mfmaEncoding);

int mDim = mfmaEncoding.getMDim();
if (mDim != 32 && mDim != 16)
return rewriter.notifyMatchFailure(op, "NYI: non-mfma32/16 intrinsics");

int numThreads = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
Value warpSize = i32_val(numThreads);
Value tid = tid_val();
Value warpId = udiv(tid, warpSize);
Value laneId = urem(tid, warpSize);

// Given that the MFMA layout for the A tensor arranges threads in a
// column-major manner, the current tid sits at row (tid % mDim). When we set
// up the blocked layout for the A scale tensor, we made sure that it has a
// threadsPerWarp = [M=mDim, K=64/mDim]. So the threads holding scale values
// for the current thread start at ((tid % mDim) * (64 / mDim)).
Value offset = mul(urem(laneId, i32_val(mDim)), i32_val(numThreads / mDim));

if (mDim == 32) {
// One mfma32 intrinsic processes a 32x8 A tensor slice. Due to how we
// tile, the same warp owns the whole K dim. Inside a warp, each thread
// only holds 4 consecutive elements along K (a 1x4 vector). We need to
// tile the warp 4 times to cover 32 values along K. So for a thread, the
// first 4 1x4 vectors it holds share the first scale value at row (tid %
// mDim); the second 4 1x4 vectors share the second scale value at row
// (tid % mDim); and so forth.
std::array<Value, 2> scaleThreads = {offset, add(offset, i32_val(1))};

for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
std::array<Value, 2> si = {
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
};

for (int j = 0; j < 32; ++j) {
int index = 32 * i + j;
xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]);
}
}
} else {
assert(mDim == 16);
// One mfma16 intrinsic processes a 16x16 A tensor slice. Similarly, we
// need to tile the warp 2 times to cover 32 values along K. So for a
// thread, the first 2 1x4 vectors share the first scale value at row
// (tid % mDim), and so forth.
std::array<Value, 4> scaleThreads = {offset, add(offset, i32_val(1)),
add(offset, i32_val(2)),
add(offset, i32_val(3))};

for (auto [i, scaleVal] : llvm::enumerate(scaleVals)) {
auto si = std::array<Value, 4>{
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[0]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[1]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[2]),
targetInfo.shuffleIdx(rewriter, loc, scaleVal, scaleThreads[3]),
};

for (int j = 0; j < 32; ++j) {
int index = 32 * i + j;
xVals[index] = mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]);
}
}
}

Value result =
packLLElements(loc, getTypeConverter(), xVals, rewriter, op.getType());
rewriter.replaceOp(op, result);
return success();
}
};
} // anonymous namespace

void mlir::triton::AMD::populateUpcastMXFPToLLVMPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
const TargetInfo &targetInfo, PatternBenefit benefit) {
patterns.add<UpcastMXFPOpPattern>(typeConverter, targetInfo, benefit);
}
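As an aside (not part of the diff), the bit trick in mxfpScaleBf16 can be reproduced in a few lines of Python: shifting the E8M0 scale byte left by 7 places it in bf16's exponent field, which encodes the value 2**(byte - 127), and a byte of 0xFF is treated as NaN per the MXFP spec.

import struct

def bf16_bits_to_float(bits):
    # A bf16 pattern is the top 16 bits of a float32 pattern.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def mxfp_scale_bf16(v_bits, scale_e8m0):
    # Mirrors mxfpScaleBf16: v (a bf16 bit pattern) times an E8M0 scale byte.
    if scale_e8m0 == 0xFF:  # NaN encoding in E8M0
        return float("nan")
    # scale << 7 gives sign 0, exponent = scale, mantissa 0, i.e. 2**(scale - 127).
    return bf16_bits_to_float(v_bits) * bf16_bits_to_float(scale_e8m0 << 7)

# 0x3F80 is bf16 1.0; a scale byte of 128 encodes 2**(128 - 127) = 2.0.
assert mxfp_scale_bf16(0x3F80, 128) == 2.0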
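Likewise, the lane-to-scale mapping described in the comments above can be sketched as a standalone model (an illustration assuming the 64-lane wavefront this lowering targets, not code from the PR): each lane sits at row (lane % mDim), its scales live in the 64/mDim lanes starting at row * (64/mDim), and each of those source lanes covers an equal share of the 32 x-values held per scale register.

WARP_SIZE = 64  # AMD wavefront size assumed by this lowering

def scale_sources(lane, m_dim):
    # For one scale register held by `lane`, return (source lane, x-value
    # indices) pairs: which lane each scale is shuffled from and which of the
    # 32 x-values it applies to.
    assert m_dim in (32, 16)
    offset = (lane % m_dim) * (WARP_SIZE // m_dim)  # first lane holding this row's scales
    n_src = WARP_SIZE // m_dim    # 2 source lanes for mfma32, 4 for mfma16
    group = 32 // n_src           # 16 x-values per source for mfma32, 8 for mfma16
    return [(offset + k, list(range(k * group, (k + 1) * group))) for k in range(n_src)]

# Lane 5 with mfma32 sits at row 5; its scales come from lanes 10 and 11, the
# first covering x-values 0..15 and the second 16..31 (matching si[j / 16]).
print(scale_sources(5, 32))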