From 0b6a4aa358a6a4a61c27126272eb6a93c41f3287 Mon Sep 17 00:00:00 2001 From: Hongtao Yu Date: Tue, 3 Sep 2024 16:02:24 -0700 Subject: [PATCH] Passing TMA descriptors through grid constant (#3066) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/163 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3066 Improving the TMA kernel by passing the TMA descriptors through grid constant. Grid constant (D61692148) significantly reduces kernel invocation overhead. Also enables bias for the TMA kernel. Reviewed By: sfzhu93 Differential Revision: D61799463 --- .../experimental/gemm/triton_gemm/fp8_gemm.py | 297 +++++++++++++----- 1 file changed, 214 insertions(+), 83 deletions(-) diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index ce2cd12a38..c8dba3b402 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -680,6 +680,25 @@ def _kernel_matmul_fp8_row_imprecise_acc( tl.atomic_add(C, acc, mask=mask) +@triton.autotune( + configs=[ + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + ], + key=[ + "m_key", + "n_key", + "k_key", + ], +) +@triton.heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } +) @triton.jit def _kernel_matmul_fp8_row_tma_persistent( A_ptr, @@ -688,8 +707,12 @@ def _kernel_matmul_fp8_row_tma_persistent( M, N, K, + m_key, + n_key, + k_key, A_scale, B_scale, + Bias, stride_am, stride_ak, stride_bn, @@ -704,7 +727,10 @@ def _kernel_matmul_fp8_row_tma_persistent( BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, AB_DTYPE: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, NUM_SMS: tl.constexpr, + USE_BIAS: tl.constexpr, ) -> None: """Matmul kernel of [M, K] @ [N, K] with row-wise scales @@ -761,6 +787,7 @@ def _kernel_matmul_fp8_row_tma_persistent( dtype_fp8 = tl.float8e4nv scale_dtype = tl.float32 + bias_dtype = tl.float32 for _ in range(0, k_tiles * tiles_per_SM): ki = tl.where(ki == k_tiles - 1, 0, ki + 1) @@ -785,29 +812,128 @@ def _kernel_matmul_fp8_row_tma_persistent( b = tl._experimental_descriptor_load( B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8 ) - acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + + if fp8_fast_accum: + acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) if ki == k_tiles - 1: # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M - rn = pid_n * BLOCK_N # # Invert scaling. a_scale = tl._experimental_descriptor_load( - A_scale, [rm], [BLOCK_M], scale_dtype + A_scale, [offs_am], [BLOCK_M], scale_dtype ) b_scale = tl._experimental_descriptor_load( - B_scale, [rn], [BLOCK_N], scale_dtype + B_scale, [offs_bn], [BLOCK_N], scale_dtype ) # pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`. scale = a_scale[:, None] * b_scale[None, :] acc *= scale + + # Load and add bias if specified. 
+ if USE_BIAS: + bias = tl._experimental_descriptor_load( + Bias, [offs_bn], [BLOCK_N], bias_dtype + ) + acc += bias[None, :] + acc = acc.to(C_ptr.dtype.element_ty) - tl._experimental_descriptor_store(C_ptr, acc, [rm, rn]) + tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn]) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) +# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498). +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print( + "TMA benchmarks will be running with experimental grid constant TMA descriptor." + ) +else: + print("TMA benchmarks will be running without grid constant TMA descriptor.") + + +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_1d_tma_descriptor + ) + self.fill_2d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_2d_tma_descriptor + ) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 + ) + else: + self.cuda_descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 + ) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor( + self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size + ): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + @torch.library.custom_op("triton::matmul_fp8_row", mutates_args=()) def matmul_fp8_row( a: torch.Tensor, @@ -880,85 +1006,96 @@ def persistent_grid(META): ) if tma_persistent: - if bias is not None: - raise NotImplementedError("TMA persistent kernel doesn't support bias yet") - # used by TMA persistent kernel - TMA_SIZE = 128 - - # autotune doesn't work with TMA - # 
https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312 - - BLOCK_M = 128 - BLOCK_N = 256 - BLOCK_K = 128 - GROUP_M = 8 - num_stages = 3 - num_warps = 8 - - desc_a = torch.empty(TMA_SIZE, dtype=torch.int8) - desc_b = torch.empty(TMA_SIZE, dtype=torch.int8) - desc_c = torch.empty(TMA_SIZE, dtype=torch.int8) - desc_a_scale = torch.empty(TMA_SIZE, dtype=torch.int8) - desc_b_scale = torch.empty(TMA_SIZE, dtype=torch.int8) - - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - a_tl.data_ptr(), - M, - K, - BLOCK_M, - BLOCK_K, - a_tl.element_size(), - desc_a.data_ptr(), - ) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - b_tl.data_ptr(), - N, - K, - BLOCK_N, - BLOCK_K, - b_tl.element_size(), - desc_b.data_ptr(), - ) - triton.runtime.driver.active.utils.fill_2d_tma_descriptor( - c.data_ptr(), - M, - N, - BLOCK_M, - BLOCK_N, - c.element_size(), - desc_c.data_ptr(), - ) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor( - a_scale.data_ptr(), - M, - BLOCK_M, - a_scale.element_size(), - desc_a_scale.data_ptr(), - ) - triton.runtime.driver.active.utils.fill_1d_tma_descriptor( - b_scale.data_ptr(), - N, - BLOCK_N, - b_scale.element_size(), - desc_b_scale.data_ptr(), - ) - desc_a = torch.tensor(desc_a, device="cuda") - desc_b = torch.tensor(desc_b, device="cuda") - desc_c = torch.tensor(desc_c, device="cuda") - desc_a_scale = torch.tensor(desc_a_scale, device="cuda") - desc_b_scale = torch.tensor(desc_b_scale, device="cuda") + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + desc_helper.init_tma_descriptor("a_scale") + desc_helper.init_tma_descriptor("b_scale") + desc_helper.init_tma_descriptor("bias") + + def persistent_grid_tma(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a_tl.data_ptr(), + M, + K, + META["BLOCK_M"], + META["BLOCK_K"], + a_tl.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b_tl.data_ptr(), + N, + K, + META["BLOCK_N"], + META["BLOCK_K"], + b_tl.element_size(), + ) + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_M"], + META["BLOCK_N"], + c.element_size(), + ) + desc_helper.fill_1d_tma_descriptor( + "a_scale", + a_scale.data_ptr(), + M, + META["BLOCK_M"], + a_scale.element_size(), + ) + desc_helper.fill_1d_tma_descriptor( + "b_scale", + b_scale.data_ptr(), + N, + META["BLOCK_N"], + b_scale.element_size(), + ) + if bias is not None: + desc_helper.fill_1d_tma_descriptor( + "bias", + bias.data_ptr(), + N, + META["BLOCK_N"], + bias.element_size(), + ) + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + ), + ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale") + desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale") + desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias") # pyre-ignore[28]: - _kernel_matmul_fp8_row_tma_persistent[persistent_grid]( + _kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma]( desc_a, desc_b, desc_c, + # c, M, N, K, + m_key, + n_key, + k_key, desc_a_scale, desc_b_scale, + desc_bias, a.stride(0), a.stride(1), b.stride(0), @@ -968,18 +1105,12 @@ def persistent_grid(META): dot_out_dtype=dot_out_dtype_triton, 
allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, - GROUP_M=GROUP_M, + GROUP_M=8, AB_DTYPE=False, NUM_SMS=NUM_SMS, - num_stages=num_stages, - num_warps=num_warps, + USE_BIAS=bias is not None, ) - return c.view(output_shape) - - if imprecise_acc: + elif imprecise_acc: _kernel_matmul_fp8_row_imprecise_acc[grid]( a_tl, b_tl,
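
Illustrative usage (not part of the patch): a minimal sketch of how the TmaAutoTuneHelper introduced above is meant to be driven from a host-side launcher, mirroring the persistent_grid_tma pattern. The descriptor buffer is allocated once via init_tma_descriptor outside the grid lambda and refilled inside it, so the autotuner's BLOCK_* choices take effect before every launch. The import path, the launcher name launch_with_tma, and the kernel name my_tma_kernel are assumptions for illustration only; the exact module path and a real consuming kernel are not shown in this hunk.

    import torch
    import triton

    # Assumed import path, derived from the file touched by this patch.
    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import TmaAutoTuneHelper

    def launch_with_tma(a: torch.Tensor) -> None:
        M, K = a.shape
        helper = TmaAutoTuneHelper()
        # Allocate the 128-byte descriptor buffer once, outside the grid lambda.
        helper.init_tma_descriptor("a")

        def grid(META):
            # Refill per launch so the autotuned META["BLOCK_M"] / META["BLOCK_K"] apply.
            helper.fill_2d_tma_descriptor(
                "a",
                a.data_ptr(),
                M,
                K,
                META["BLOCK_M"],
                META["BLOCK_K"],
                a.element_size(),
            )
            return (triton.cdiv(M, META["BLOCK_M"]),)

        # Returned either as a grid-constant TMA descriptor wrapper (new Triton with
        # nv_tma_desc_type) or as a device byte tensor (fallback), matching
        # get_tma_descriptor_kernel_param above.
        desc_a = helper.get_tma_descriptor_kernel_param("a")

        # my_tma_kernel is a hypothetical @triton.jit kernel that loads tiles through
        # the descriptor with tl._experimental_descriptor_load, the way
        # _kernel_matmul_fp8_row_tma_persistent does.
        # my_tma_kernel[grid](desc_a, M, K)

The same flow is what the patched matmul_fp8_row host code does for "a", "b", "c", "a_scale", "b_scale", and "bias" before launching _kernel_matmul_fp8_row_tma_persistent.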