Skip to content

Commit

Permalink
Reorder load and scaling code to allow latency hidding for block-wise…
Browse files Browse the repository at this point in the history
… scaled GEMMs (#2600)

Summary:

The compiler may not do a good job at reordering instructions for better latency hiding due to various reasons. Thus I'm tweaking the kernel code here.

Previously in the block-wise scaled GEMM kernel, the scaling logic followed `tl.load` and the compiler was not able to move the logic before the loads once the loads are pipelined. This created a situation where the scaling logic was blocked by the load barriers, which is unnecessary as they are independent. Since the barrier is only needed by the `dot` operation, I'm moving the scaling logic before the loads. 


 {F1640448911}


While we should fix the compiler to be more robust, I'm making a source change as a workaround.

Differential Revision: D57473133
  • Loading branch information
htyu committed May 20, 2024
1 parent 37c283c commit 64df84e
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,17 +529,6 @@ def _kernel_matmul_fp8_block(
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
scale_next = 0.0
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)

a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
# Note: Due to split_k access "pid_k" = k * SPLIT_K + pid_z
# Access a_scale[pid_m, k * SPLIT_K + pid_z]
# and b_scale[k * SPLIT_K + pid_z, pid_n]
Expand Down Expand Up @@ -573,6 +562,17 @@ def _kernel_matmul_fp8_block(
inv_scale = 1.0 / scale
scale_next_inv_scale = scale_next / scale

if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)

a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE:
a = a.to(C.dtype.element_ty)
b = b.to(C.dtype.element_ty)
if fp8_fast_accum:
acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

Expand Down

0 comments on commit 64df84e

Please sign in to comment.