From 1dbe3ac1d9000008d689e472cb8407dcbe2a3214 Mon Sep 17 00:00:00 2001
From: Jiecao Yu
Date: Tue, 4 Jun 2024 16:46:59 -0700
Subject: [PATCH] row-wise quant fp8 gemm perf bench | fp8_fast_accum test
 (#2666)

Summary: Modify the FP8 GEMM performance benchmark to add the cases with fast_accum=False and max_num_imprecise_acc=32.

Reviewed By: sryap, htyu, jianyuh

Differential Revision: D57613375
---
 .../gemm/test/fp8_gemm_benchmark.py           |  72 +++-
 .../experimental/gemm/triton_gemm/fp8_gemm.py | 394 ++++++++++++++++--
 2 files changed, 434 insertions(+), 32 deletions(-)

diff --git a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
index ae300f7bb8..e28f8eef35 100644
--- a/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
+++ b/fbgemm_gpu/experimental/gemm/test/fp8_gemm_benchmark.py
@@ -54,15 +54,34 @@ def _run_benchmark(
         sec = ms / 1e3
         perf_str = f"{tflops / sec:.2f}"
         print(
-            f"{(tag + ':').ljust(20)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}"
+            f"{(tag + ':').ljust(40)}\tshape {str(shape):<25} tflops {perf_str:<8} ms {ms:.3f}"
         )
 
-    shapes = [(8192, 8192, 8192), (65536, 8192, 7168), (65536, 3584, 8192)]
+    shapes = [
+        (8192, 8192, 8192),
+        (65536, 8192, 7168),
+        (65536, 3584, 8192),
+        (8192, 14336, 4096),
+    ]
     for shape in shapes:
         _run_benchmark(bf16_bench, shape=shape, tag="bf16")
         _run_benchmark(scale_row_bench, shape=shape, tag="fp8 scale + row gemm")
         _run_benchmark(scale_block_bench, shape=shape, tag="fp8 scale + block gemm")
-        _run_benchmark(row_gemm_bench, shape=shape, tag="fp8 row gemm only")
+        _run_benchmark(
+            row_gemm_bench,
+            shape=shape,
+            tag="fp8 row gemm only | fp8_fast_accum=True",
+        )
+        _run_benchmark(
+            row_gemm_bench_no_fast_acc,
+            shape=shape,
+            tag="fp8 row gemm only | fp8_fast_accum=False",
+        )
+        _run_benchmark(
+            row_gemm_bench_imprecise_acc,
+            shape=shape,
+            tag="fp8 row gemm only | max_num_imprecise_acc=32",
+        )
         _run_benchmark(block_gemm_bench, shape=shape, tag="fp8 block gemm only")
 
 
@@ -118,6 +137,53 @@ def run_gemm() -> Tensor:
     return run_gemm
 
 
+def row_gemm_bench_no_fast_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
+    # Benchmark only row-wise gemm, caching scaling.
+    x_fp8: TensorWrapper
+    w_fp8: TensorWrapper
+    x_scale: Tensor
+    w_scale: Tensor
+    x_fp8, x_scale = quantize_fp8_row(x)
+    w_fp8, w_scale = quantize_fp8_row(w)
+
+    def run_gemm() -> Tensor:
+        return matmul_fp8_row(
+            x_fp8,
+            w_fp8,
+            x_scale,
+            w_scale,
+            dot_out_dtype=torch.float32,
+            allow_tf32=True,
+            fp8_fast_accum=False,
+        )
+
+    return run_gemm
+
+
+def row_gemm_bench_imprecise_acc(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
+    # Benchmark only row-wise gemm, caching scaling.
+    x_fp8: TensorWrapper
+    w_fp8: TensorWrapper
+    x_scale: Tensor
+    w_scale: Tensor
+    x_fp8, x_scale = quantize_fp8_row(x)
+    w_fp8, w_scale = quantize_fp8_row(w)
+
+    def run_gemm() -> Tensor:
+        return matmul_fp8_row(
+            x_fp8,
+            w_fp8,
+            x_scale,
+            w_scale,
+            dot_out_dtype=torch.float32,
+            allow_tf32=True,
+            fp8_fast_accum=True,
+            imprecise_acc=True,
+        )
+
+    return run_gemm
+
+
 def scale_block_bench(x: Tensor, w: Tensor) -> Callable[[], Tensor]:
     def run_gemm() -> Tensor:
         x_fp8: TensorWrapper
diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index e2a41b50f5..5dbe179cbd 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -198,11 +198,6 @@ def get_configs_io_bound() -> List[Config]:
         "n_key",
         "k_key",
     ],
-    prune_configs_by={
-        "early_config_prune": early_config_prune,
-        "perf_model": estimate_matmul_time,
-        "top_k": 10,
-    },
 )
 @triton.heuristics(
     {
@@ -336,6 +331,295 @@ def _kernel_matmul_fp8_row(
         tl.atomic_add(C, acc, mask=mask)
 
 
+@triton.autotune(
+    configs=MATMUL_CONFIGS,
+    key=[
+        "m_key",
+        "n_key",
+        "k_key",
+    ],
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def _kernel_matmul_fp8_row_no_fast_acc(
+    A,
+    B,
+    C,
+    M,
+    N,
+    K,
+    m_key,
+    n_key,
+    k_key,
+    A_scale,
+    B_scale,
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    stride_cm,
+    stride_cn,
+    dot_out_dtype: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    fp8_fast_accum: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    AB_DTYPE: tl.constexpr,
+) -> None:
+    """Matmul kernel of [M, K] @ [N, K] with row-wise scales
+
+    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
+
+    Args:
+        A (TensorWrapper): [M, K] input tensor.
+        B (TensorWrapper): [N, K] input tensor.
+        C (TensorWrapper): [M, N] output tensor.
+        M (int): M dimension of input tensor.
+        N (int): N dimension of input tensor.
+        K (int): K dimension of input tensor.
+        m_key (int): Autotuning key for M dimension of input tensor.
+        n_key (int): Autotuning key for N dimension of input tensor.
+        k_key (int): Autotuning key for K dimension of input tensor.
+        A_scale (TensorWrapper): [M] scale tensor per row. A / A_scale = original A
+        B_scale (TensorWrapper): [N] scale tensor per row. B / B_scale = original B
+        stride_am (int): Stride of M dimension of A.
+        stride_ak (int): Stride of K dimension of A.
+        stride_bn (int): Stride of N dimension of B.
+        stride_bk (int): Stride of K dimension of B.
+        stride_cm (int): Stride of M dimension of C.
+        stride_cn (int): Stride of N dimension of C.
+        dot_out_dtype (torch.dtype): Output type of tensor core.
+        allow_tf32 (bool): Whether to use TF32 for tensor core.
+        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+        BLOCK_M (int): Block size for M dimension.
+        BLOCK_N (int): Block size for N dimension.
+        BLOCK_K (int): Block size for K dimension.
+        GROUP_M (int): Number of groups for M dimension swizzle.
+        SPLIT_K (int): Number of SMs to launch per row.
+        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+    """
+    # Matrix multiplication.
+    pid = tl.program_id(0)
+    pid_z = tl.program_id(1)
+    grid_m = tl.cdiv(M, BLOCK_M)
+    grid_n = tl.cdiv(N, BLOCK_N)
+    # Re-order program ID for better L2 performance (swizzle).
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+    # Do matrix multiplication.
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+    # Pointers.
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+
+    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            k_remaining = K - k * (BLOCK_K * SPLIT_K)
+            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+        if AB_DTYPE:
+            a = a.to(C.dtype.element_ty)
+            b = b.to(C.dtype.element_ty)
+        # fp8_fast_accum = False
+        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+
+        A += BLOCK_K * SPLIT_K * stride_ak
+        B += BLOCK_K * SPLIT_K * stride_bk
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # Invert scaling.
+    a_scale = tl.load(A_scale + rm, mask=rm < M)
+    b_scale = tl.load(B_scale + rn, mask=rn < N)
+    # Invert vector, then multiply on matrix for speed.
+    inv_a_scale = 1.0 / a_scale
+    inv_b_scale = 1.0 / b_scale
+    # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
+    scale = inv_a_scale[:, None] * inv_b_scale[None, :]
+    acc *= scale
+
+    acc = acc.to(C.dtype.element_ty)
+    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    # Handles write-back with reduction-splitting
+    if SPLIT_K == 1:
+        tl.store(C, acc, mask=mask)
+    else:
+        tl.atomic_add(C, acc, mask=mask)
+
+
+@triton.autotune(
+    configs=MATMUL_CONFIGS,
+    key=[
+        "m_key",
+        "n_key",
+        "k_key",
+    ],
+)
+@triton.heuristics(
+    {
+        "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
+    }
+)
+@triton.jit
+def _kernel_matmul_fp8_row_imprecise_acc(
+    A,
+    B,
+    C,
+    M,
+    N,
+    K,
+    m_key,
+    n_key,
+    k_key,
+    A_scale,
+    B_scale,
+    stride_am,
+    stride_ak,
+    stride_bn,
+    stride_bk,
+    stride_cm,
+    stride_cn,
+    dot_out_dtype: tl.constexpr,
+    allow_tf32: tl.constexpr,
+    fp8_fast_accum: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    GROUP_M: tl.constexpr,
+    SPLIT_K: tl.constexpr,
+    EVEN_K: tl.constexpr,
+    AB_DTYPE: tl.constexpr,
+) -> None:
+    """Matmul kernel of [M, K] @ [N, K] with row-wise scales
+
+    performs swizzled matmul in [BLOCK_M, BLOCK_K] with [BLOCK_K, BLOCK_N] tiles.
+
+    Args:
+        A (TensorWrapper): [M, K] input tensor.
+        B (TensorWrapper): [N, K] input tensor.
+        C (TensorWrapper): [M, N] output tensor.
+        M (int): M dimension of input tensor.
+        N (int): N dimension of input tensor.
+        K (int): K dimension of input tensor.
+        m_key (int): Autotuning key for M dimension of input tensor.
+        n_key (int): Autotuning key for N dimension of input tensor.
+        k_key (int): Autotuning key for K dimension of input tensor.
+        A_scale (TensorWrapper): [M] scale tensor per row. A / A_scale = original A
+        B_scale (TensorWrapper): [N] scale tensor per row. B / B_scale = original B
+        stride_am (int): Stride of M dimension of A.
+        stride_ak (int): Stride of K dimension of A.
+        stride_bn (int): Stride of N dimension of B.
+        stride_bk (int): Stride of K dimension of B.
+        stride_cm (int): Stride of M dimension of C.
+        stride_cn (int): Stride of N dimension of C.
+        dot_out_dtype (torch.dtype): Output type of tensor core.
+        allow_tf32 (bool): Whether to use TF32 for tensor core.
+        fp8_fast_accum (bool): Whether to use fast accumulation for tensor core.
+        BLOCK_M (int): Block size for M dimension.
+        BLOCK_N (int): Block size for N dimension.
+        BLOCK_K (int): Block size for K dimension.
+        GROUP_M (int): Number of groups for M dimension swizzle.
+        SPLIT_K (int): Number of SMs to launch per row.
+        EVEN_K (bool): Whether K is evenly divisible by BLOCK_K * SPLIT_K.
+        AB_DTYPE (bool): Whether to cast A and B to C.dtype before tensor core.
+    """
+    # Matrix multiplication.
+    pid = tl.program_id(0)
+    pid_z = tl.program_id(1)
+    grid_m = tl.cdiv(M, BLOCK_M)
+    grid_n = tl.cdiv(N, BLOCK_N)
+    # Re-order program ID for better L2 performance (swizzle).
+    width = GROUP_M * grid_n
+    group_id = pid // width
+    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+    pid_m = group_id * GROUP_M + (pid % group_size)
+    pid_n = (pid % width) // (group_size)
+    # Do matrix multiplication.
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+    # Pointers.
+    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
+
+    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            k_remaining = K - k * (BLOCK_K * SPLIT_K)
+            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
+            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+        if AB_DTYPE:
+            a = a.to(C.dtype.element_ty)
+            b = b.to(C.dtype.element_ty)
+        if fp8_fast_accum:
+            acc = tl.dot(
+                a,
+                b,
+                acc,
+                max_num_imprecise_acc=32,
+                out_dtype=dot_out_dtype,
+                allow_tf32=allow_tf32,
+            )
+        else:
+            acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
+
+        A += BLOCK_K * SPLIT_K * stride_ak
+        B += BLOCK_K * SPLIT_K * stride_bk
+
+    # rematerialize rm and rn to save registers
+    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    # Invert scaling.
+    a_scale = tl.load(A_scale + rm, mask=rm < M)
+    b_scale = tl.load(B_scale + rn, mask=rn < N)
+    # Invert vector, then multiply on matrix for speed.
+    inv_a_scale = 1.0 / a_scale
+    inv_b_scale = 1.0 / b_scale
+    # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
+    scale = inv_a_scale[:, None] * inv_b_scale[None, :]
+    acc *= scale
+
+    acc = acc.to(C.dtype.element_ty)
+    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+    mask = (rm < M)[:, None] & (rn < N)[None, :]
+    # Handles write-back with reduction-splitting
+    if SPLIT_K == 1:
+        tl.store(C, acc, mask=mask)
+    else:
+        tl.atomic_add(C, acc, mask=mask)
+
+
 def matmul_fp8_row(
     a: TensorWrapper,
     b: TensorWrapper,
@@ -344,6 +628,7 @@ def matmul_fp8_row(
     dot_out_dtype: Optional[torch.dtype] = None,
     allow_tf32: bool = True,
     fp8_fast_accum: bool = True,
+    imprecise_acc: bool = False,
 ) -> torch.Tensor:
     """
     Performs matmul on [M, K] and [N, K] fp8 matrices with row-wise scalings [M], [N].
@@ -379,30 +664,81 @@ def grid(META):
             META["SPLIT_K"],
         )
 
-    _kernel_matmul_fp8_row[grid](
-        a,
-        b,
-        c,
-        M,
-        N,
-        K,
-        m_key,
-        n_key,
-        k_key,
-        a_scale,
-        b_scale,
-        a.stride(0),
-        a.stride(1),
-        b.stride(0),
-        b.stride(1),
-        c.stride(0),
-        c.stride(1),
-        dot_out_dtype=dot_out_dtype_triton,
-        allow_tf32=allow_tf32,
-        fp8_fast_accum=fp8_fast_accum,
-        GROUP_M=8,
-        AB_DTYPE=False,
-    )
+    if imprecise_acc:
+        _kernel_matmul_fp8_row_imprecise_acc[grid](
+            a,
+            b,
+            c,
+            M,
+            N,
+            K,
+            m_key,
+            n_key,
+            k_key,
+            a_scale,
+            b_scale,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            dot_out_dtype=dot_out_dtype_triton,
+            allow_tf32=allow_tf32,
+            fp8_fast_accum=fp8_fast_accum,
+            GROUP_M=8,
+            AB_DTYPE=False,
+        )
+    elif fp8_fast_accum:
+        _kernel_matmul_fp8_row[grid](
+            a,
+            b,
+            c,
+            M,
+            N,
+            K,
+            m_key,
+            n_key,
+            k_key,
+            a_scale,
+            b_scale,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            dot_out_dtype=dot_out_dtype_triton,
+            allow_tf32=allow_tf32,
+            fp8_fast_accum=fp8_fast_accum,
+            GROUP_M=8,
+            AB_DTYPE=False,
+        )
+    else:
+        _kernel_matmul_fp8_row_no_fast_acc[grid](
+            a,
+            b,
+            c,
+            M,
+            N,
+            K,
+            m_key,
+            n_key,
+            k_key,
+            a_scale,
+            b_scale,
+            a.stride(0),
+            a.stride(1),
+            b.stride(0),
+            b.stride(1),
+            c.stride(0),
+            c.stride(1),
+            dot_out_dtype=dot_out_dtype_triton,
+            allow_tf32=allow_tf32,
+            fp8_fast_accum=fp8_fast_accum,
+            GROUP_M=8,
+            AB_DTYPE=False,
+        )
     return c
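A minimal usage sketch of the options exercised by this change, using only functions and parameters that appear in the diff; the import path, device, dtype, and shapes below are illustrative assumptions, not part of the patch:

    import torch

    # Import path assumed from the files touched by this diff.
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
        matmul_fp8_row,
        quantize_fp8_row,
    )

    # (M, N, K) = (8192, 14336, 4096), the shape added to the benchmark above.
    x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)   # [M, K]
    w = torch.randn(14336, 4096, device="cuda", dtype=torch.bfloat16)  # [N, K]

    # Row-wise FP8 quantization, cached once as in the benchmark helpers.
    x_fp8, x_scale = quantize_fp8_row(x)
    w_fp8, w_scale = quantize_fp8_row(w)

    # Existing path: fast accumulation in the tensor core.
    out_fast = matmul_fp8_row(x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=True)

    # New path: full-precision accumulation (_kernel_matmul_fp8_row_no_fast_acc).
    out_no_fast = matmul_fp8_row(x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=False)

    # New path: fast accumulation capped at 32 imprecise accumulations
    # (_kernel_matmul_fp8_row_imprecise_acc, max_num_imprecise_acc=32).
    out_imprecise = matmul_fp8_row(
        x_fp8, w_fp8, x_scale, w_scale, fp8_fast_accum=True, imprecise_acc=True
    )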