From 64df84eb2569c8573e35840cdfb4243cd3ff548b Mon Sep 17 00:00:00 2001
From: Hongtao Yu
Date: Thu, 16 May 2024 19:14:28 -0700
Subject: [PATCH] Reorder load and scaling code to allow latency hiding for
 block-wise scaled GEMMs (#2600)

Summary:
For various reasons, the compiler may not do a good job of reordering instructions for better latency hiding, so I'm tweaking the kernel code here.

Previously in the block-wise scaled GEMM kernel, the scaling logic followed the `tl.load` calls, and the compiler was not able to move that logic ahead of the loads once the loads were pipelined. This created a situation where the scaling logic was blocked by the load barriers, which is unnecessary since the two are independent. As the barrier is only needed by the `dot` operation, I'm moving the scaling logic before the loads.

{F1640448911}

While we should make the compiler more robust, I'm making this source change as a workaround.

Differential Revision: D57473133
---
 .../experimental/gemm/triton_gemm/fp8_gemm.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
index 181a51568c..2bd270e14f 100644
--- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
+++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -529,17 +529,6 @@ def _kernel_matmul_fp8_block(
     _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
     scale_next = 0.0
     for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
-        if EVEN_K:
-            a = tl.load(A)
-            b = tl.load(B)
-        else:
-            k_remaining = K - k * (BLOCK_K * SPLIT_K)
-
-            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
-            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
-        if AB_DTYPE:
-            a = a.to(C.dtype.element_ty)
-            b = b.to(C.dtype.element_ty)
         # Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
         # Access a_scale[pid_m, k * SPLIT_K + pid_z]
         # and b_scale[k * SPLIT_K + pid_z, pid_n]
@@ -573,6 +562,17 @@ def _kernel_matmul_fp8_block(
 
         inv_scale = 1.0 / scale
         scale_next_inv_scale = scale_next / scale
+        if EVEN_K:
+            a = tl.load(A)
+            b = tl.load(B)
+        else:
+            k_remaining = K - k * (BLOCK_K * SPLIT_K)
+
+            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
+            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
+        if AB_DTYPE:
+            a = a.to(C.dtype.element_ty)
+            b = b.to(C.dtype.element_ty)
         if fp8_fast_accum:
             acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
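
Illustrative note (not part of the patch): the sketch below shows the ordering the change establishes in a stripped-down Triton K-loop: per-block scale handling first, pipelined tile loads last, immediately before the dot that consumes them. The kernel signature, the scale-tensor layout, and the absence of masking and SPLIT_K handling are simplifications assumed for the sketch, not the real _kernel_matmul_fp8_block interface.

    import triton
    import triton.language as tl


    @triton.jit
    def _k_loop_order_sketch(
        A, B, C, A_scale, B_scale,
        M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ):
        # Assumes M, N, K are multiples of the block sizes, so no masks are needed.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        rk = tl.arange(0, BLOCK_K)
        A_ptrs = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
        B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_K)):
            # Scale handling is issued first: it does not depend on the A/B tiles,
            # so it need not wait behind the barriers the software pipeliner
            # inserts around the pipelined tile loads.
            # Assumed layouts: A_scale is [num_m_blocks, num_k_blocks] row-major,
            # B_scale is [num_k_blocks, num_n_blocks] row-major.
            a_s = tl.load(A_scale + pid_m * tl.cdiv(K, BLOCK_K) + k)
            b_s = tl.load(B_scale + k * tl.cdiv(N, BLOCK_N) + pid_n)
            scale = a_s * b_s
            # Tile loads come last, right before the dot that actually needs them.
            a = tl.load(A_ptrs)
            b = tl.load(B_ptrs)
            acc += tl.dot(a, b) * scale
            A_ptrs += BLOCK_K * stride_ak
            B_ptrs += BLOCK_K * stride_bk
        C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        tl.store(C_ptrs, acc)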