Test chained dot (FP8 case, shuffle conversion)
Summary: Tests the shuffle conversion in the case of an FP8 dot that consumes the output of another FP8 dot.

Test: pytest third_party/amd/python/test/test_chained_dot_fp8.py
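At its core, the kernel added below chains two tl.dot ops so that the second one consumes the (down-converted) output of the first. As a rough, illustrative sketch of that pattern only — the kernel name and flat indexing here are made up, and the real kernel below additionally uses block pointers, descale factors, and a loop over N:

import triton
import triton.language as tl

@triton.jit
def _toy_chained_dot(Q, K, V, Out,  # flat (M x D), (D x N), (N x D) buffers; Out assumed FP32
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    q = tl.load(Q + offs_m[:, None] * BLOCK_D + offs_d[None, :])
    k = tl.load(K + offs_d[:, None] * BLOCK_N + offs_n[None, :])
    v = tl.load(V + offs_n[:, None] * BLOCK_D + offs_d[None, :])
    s = tl.dot(q, k)                # dot #1: FP8 operands, FP32 (MFMA) accumulator
    acc = tl.dot(s.to(q.dtype), v)  # dot #2 consumes dot #1's output: the shuffle-conversion path
    tl.store(Out + offs_m[:, None] * BLOCK_D + offs_d[None, :], acc)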
Showing 1 changed file with 173 additions and 0 deletions.
third_party/amd/python/test/test_chained_dot_fp8.py
@@ -0,0 +1,173 @@
""" | ||
Testing the (FP8) case of a dot op that consumes the output (MFMA) of | ||
another dot op as an input. | ||
""" | ||
|
||
import math | ||
import pytest | ||
import sys | ||
import torch | ||
|
||
import triton | ||
import triton.language as tl | ||
|
||
TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') | ||
float8: tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz | ||
torch.manual_seed(42) | ||
|
||
|
||
@triton.jit
def _chained_dot(
    Q,
    K,
    V,
    Out,
    q_desc,
    k_desc,
    v_desc,
    s_sc,
    s_desc,
    o_sc,
    stride_qz,
    stride_qm,
    stride_qd,
    stride_kz,
    stride_kn,
    stride_kd,
    stride_vz,
    stride_vd,
    stride_vn,
    stride_oz,
    stride_om,
    stride_od,
    Z,
    M,
    N,
    BLOCK_D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_FP8: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_z = tl.program_id(1)
    qkv_offset = off_z * stride_qz
    Q_block_ptr = tl.make_block_ptr(base=Q + qkv_offset, shape=(N, BLOCK_D), strides=(stride_qm, stride_qd),
                                    offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0))
    K_block_ptr = tl.make_block_ptr(base=K + qkv_offset, shape=(BLOCK_D, N), strides=(stride_kd, stride_kn),
                                    offsets=(0, 0), block_shape=(BLOCK_D, BLOCK_N), order=(0, 1))
    V_block_ptr = tl.make_block_ptr(base=V + qkv_offset, shape=(N, BLOCK_D), strides=(stride_vn, stride_vd),
                                    offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_D), order=(0, 1))

    # Combined factors: descale the FP8 inputs of each dot and apply the requested
    # output scale (s_sc for the intermediate, o_sc for the final result).
    s_scale = q_desc * k_desc * s_sc
    acc_scale = s_desc * v_desc * o_sc

    q = tl.load(Q_block_ptr)

    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
    lo, hi = 0, N
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # First dot: accumulates q @ k in FP32 (the MFMA output).
        k = tl.load(K_block_ptr)
        s = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        s += tl.dot(q, k)

        if USE_FP8:
            s *= s_scale

        # Second dot consumes the first dot's output, converted to v's dtype
        # (FP8 in the case under test), which exercises the shuffle conversion.
        v = tl.load(V_block_ptr)
        acc += tl.dot(s.to(v.dtype), v)

        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))

    if USE_FP8:
        acc *= acc_scale

    O_block_ptr = tl.make_block_ptr(base=Out + qkv_offset, shape=(N, BLOCK_D), strides=(stride_om, stride_od),
                                    offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_D), order=(1, 0))
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))


class chained_dot_fn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, msize=32, q_desc=1.0, k_desc=1.0, v_desc=1.0, s_sc=1.0, s_desc=1.0, o_sc=1.0):
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        assert msize in {16, 32}
        o = torch.empty_like(q, dtype=v.dtype)

        BLOCK_M = 128 if q.dtype == float8 else 256
        if BLOCK_M > q.shape[1]:
            BLOCK_M = int(math.pow(2, math.floor(math.log2(q.shape[1]))))
        BLOCK_N = 32
        if BLOCK_N > k.shape[1]:
            BLOCK_N = int(math.pow(2, math.floor(math.log2(k.shape[1]))))
        waves_per_eu = 2
        num_warps = 4 if q.dtype == float8 else 8
        num_stages = 1

        grid = (triton.cdiv(q.shape[1], BLOCK_M), q.shape[0], 1)

        _chained_dot[grid](q, k, v, o, q_desc, k_desc, v_desc, s_sc, s_desc, o_sc,
                           q.stride(0), q.stride(1), q.stride(2),
                           k.stride(0), k.stride(1), k.stride(2),
                           v.stride(0), v.stride(1), v.stride(2),
                           o.stride(0), o.stride(1), o.stride(2),
                           Z=q.shape[0], M=q.shape[1], N=k.shape[1],
                           BLOCK_D=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
                           USE_FP8=(q.dtype == float8),
                           waves_per_eu=waves_per_eu, num_warps=num_warps,
                           num_stages=num_stages, matrix_instr_nonkdim=msize)

        return o


chained_dot = chained_dot_fn.apply


# Quantizes x to FP8 using a power-of-two scale (backed off by `margin` bits);
# returns the FP8 tensor, the scale, and its reciprocal (the descale factor).
def to_float8(x, dtype=float8, margin: float = 1.0):
    finfo = torch.finfo(dtype)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)
    scale = math.pow(2, math.floor(math.log2(scale.float().item())) - margin)
    x_scaled = (x.float() * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scaled.to(dtype), scale, 1.0 / scale


@pytest.mark.parametrize('M, N, D, dtype, msize', [(*shape, dtype, msize)
                                                   for shape in [(128, 64, 32), (256, 128, 128)]
                                                   for dtype in ['fp8']
                                                   for msize in [16, 32]])
def test_chained_dot(M, N, D, dtype, msize):
    if dtype == 'fp8':
        assert float8 is not None

    BATCH = 1
    q = torch.empty((BATCH, M, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
    k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
    v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)

    if dtype == 'fp8':
        q_f8, _, q_desc = to_float8(q)
        k_f8, _, k_desc = to_float8(k)
        v_f8, _, v_desc = to_float8(v)

        s = torch._scaled_mm(q_f8[0], k_f8[0].transpose(0, 1), out_dtype=torch.float32,
                             scale_a=torch.tensor(q_desc, dtype=torch.float32, device="cuda"),
                             scale_b=torch.tensor(k_desc, dtype=torch.float32, device="cuda"))
        s_f8, s_sc, s_desc = to_float8(s)
        ref = torch._scaled_mm(s_f8, v_f8[0].transpose(0, 1), out_dtype=torch.float32,
                               scale_a=torch.tensor(s_desc, dtype=torch.float32, device="cuda"),
                               scale_b=torch.tensor(v_desc, dtype=torch.float32, device="cuda"))
        ref_f8, ref_sc, _ = to_float8(ref)

        tri_out = chained_dot(q_f8, k_f8, v_f8, msize, q_desc, k_desc, v_desc, s_sc, s_desc, ref_sc)

        assert tri_out.isnan().sum() == 0
        torch.testing.assert_close(tri_out[0].float(), ref_f8.float(), atol=1e-2, rtol=0)

    else:
        s = torch.matmul(q, k.transpose(1, 2))
        ref = torch.matmul(s, v.transpose(1, 2))

        tri_out = chained_dot(q, k, v, msize)
        torch.testing.assert_close(tri_out, ref, atol=1e-2, rtol=0)
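Outside pytest, the new op can also be driven directly. A minimal, hypothetical driver sketch (assumes the test module is importable and an FP8-capable AMD GPU), mirroring the 'fp8' branch of test_chained_dot but leaving the intermediate and output scales at 1.0:

import torch
from test_chained_dot_fp8 import chained_dot, to_float8

BATCH, M, N, D = 1, 128, 64, 32
q = torch.empty((BATCH, M, D), dtype=torch.float16, device="cuda").normal_(std=0.5)
k = torch.empty((BATCH, N, D), dtype=torch.float16, device="cuda").normal_(std=0.5)
v = torch.empty((BATCH, D, N), dtype=torch.float16, device="cuda").normal_(std=0.5)

q_f8, _, q_desc = to_float8(q)
k_f8, _, k_desc = to_float8(k)
v_f8, _, v_desc = to_float8(v)

# msize feeds matrix_instr_nonkdim (the MFMA non-K dimension, 16 or 32);
# s_sc/s_desc/o_sc of 1.0 skip re-quantization of the intermediate and output.
out = chained_dot(q_f8, k_f8, v_f8, 32, q_desc, k_desc, v_desc, 1.0, 1.0, 1.0)
print(out.shape, out.dtype)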