Skip to content

Commit

Permalink
Passing TMA descriptors through grid constant (#3066)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#163

Pull Request resolved: #3066

Improving the TMA kernel by passing the TMA descriptors through grid constant. Grid constant (D61692148) significantly reduces kernel invocation overhead.

Also enables bias for the TMA kernel.

Reviewed By: sfzhu93

Differential Revision: D61799463
  • Loading branch information
htyu authored and facebook-github-bot committed Sep 3, 2024
1 parent 225ac16 commit 58bf3de
Showing 1 changed file with 214 additions and 83 deletions.
297 changes: 214 additions & 83 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,25 @@ def _kernel_matmul_fp8_row_imprecise_acc(
tl.atomic_add(C, acc, mask=mask)


@triton.autotune(
configs=[
Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
],
key=[
"m_key",
"n_key",
"k_key",
],
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent(
A_ptr,
Expand All @@ -688,8 +707,12 @@ def _kernel_matmul_fp8_row_tma_persistent(
M,
N,
K,
m_key,
n_key,
k_key,
A_scale,
B_scale,
Bias,
stride_am,
stride_ak,
stride_bn,
Expand All @@ -704,7 +727,10 @@ def _kernel_matmul_fp8_row_tma_persistent(
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
AB_DTYPE: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_BIAS: tl.constexpr,
) -> None:
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales
Expand Down Expand Up @@ -761,6 +787,7 @@ def _kernel_matmul_fp8_row_tma_persistent(

dtype_fp8 = tl.float8e4nv
scale_dtype = tl.float32
bias_dtype = tl.float32

for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
Expand All @@ -785,29 +812,128 @@ def _kernel_matmul_fp8_row_tma_persistent(
b = tl._experimental_descriptor_load(
B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
)
acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

if fp8_fast_accum:
acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
else:
acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N

# # Invert scaling.
a_scale = tl._experimental_descriptor_load(
A_scale, [rm], [BLOCK_M], scale_dtype
A_scale, [offs_am], [BLOCK_M], scale_dtype
)
b_scale = tl._experimental_descriptor_load(
B_scale, [rn], [BLOCK_N], scale_dtype
B_scale, [offs_bn], [BLOCK_N], scale_dtype
)
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale

# Load and add bias if specified.
if USE_BIAS:
bias = tl._experimental_descriptor_load(
Bias, [offs_bn], [BLOCK_N], bias_dtype
)
acc += bias[None, :]

acc = acc.to(C_ptr.dtype.element_ty)

tl._experimental_descriptor_store(C_ptr, acc, [rm, rn])
tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)


# check if we have the TMA version in Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor."
)
else:
print("TMA benchmarks will be running without grid constant TMA descriptor.")


class TmaAutoTuneHelper:

# duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc

def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()

TMA_SIZE = 128

def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}

# Call this method outside of the lambda function for grid size
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)

# Call this method inside the lambda function for grid size
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

# Call this method inside the lambda function for grid size
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]


@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
a: torch.Tensor,
Expand Down Expand Up @@ -880,85 +1006,96 @@ def persistent_grid(META):
)

if tma_persistent:
if bias is not None:
raise NotImplementedError("TMA persistent kernel doesn't support bias yet")

# used by TMA persistent kernel
TMA_SIZE = 128

# autotune doesn't work with TMA
# https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312

BLOCK_M = 128
BLOCK_N = 256
BLOCK_K = 128
GROUP_M = 8
num_stages = 3
num_warps = 8

desc_a = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_b = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_c = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_a_scale = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_b_scale = torch.empty(TMA_SIZE, dtype=torch.int8)

triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
a_tl.data_ptr(),
M,
K,
BLOCK_M,
BLOCK_K,
a_tl.element_size(),
desc_a.data_ptr(),
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
b_tl.data_ptr(),
N,
K,
BLOCK_N,
BLOCK_K,
b_tl.element_size(),
desc_b.data_ptr(),
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
c.data_ptr(),
M,
N,
BLOCK_M,
BLOCK_N,
c.element_size(),
desc_c.data_ptr(),
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
a_scale.data_ptr(),
M,
BLOCK_M,
a_scale.element_size(),
desc_a_scale.data_ptr(),
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
b_scale.data_ptr(),
N,
BLOCK_N,
b_scale.element_size(),
desc_b_scale.data_ptr(),
)
desc_a = torch.tensor(desc_a, device="cuda")
desc_b = torch.tensor(desc_b, device="cuda")
desc_c = torch.tensor(desc_c, device="cuda")
desc_a_scale = torch.tensor(desc_a_scale, device="cuda")
desc_b_scale = torch.tensor(desc_b_scale, device="cuda")
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")
desc_helper.init_tma_descriptor("b")
desc_helper.init_tma_descriptor("c")
desc_helper.init_tma_descriptor("a_scale")
desc_helper.init_tma_descriptor("b_scale")
desc_helper.init_tma_descriptor("bias")

def persistent_grid_tma(META):
nonlocal desc_helper
desc_helper.fill_2d_tma_descriptor(
"a",
a_tl.data_ptr(),
M,
K,
META["BLOCK_M"],
META["BLOCK_K"],
a_tl.element_size(),
)

desc_helper.fill_2d_tma_descriptor(
"b",
b_tl.data_ptr(),
N,
K,
META["BLOCK_N"],
META["BLOCK_K"],
b_tl.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"c",
c.data_ptr(),
M,
N,
META["BLOCK_M"],
META["BLOCK_N"],
c.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"a_scale",
a_scale.data_ptr(),
M,
META["BLOCK_M"],
a_scale.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"b_scale",
b_scale.data_ptr(),
N,
META["BLOCK_N"],
b_scale.element_size(),
)
if bias is not None:
desc_helper.fill_1d_tma_descriptor(
"bias",
bias.data_ptr(),
N,
META["BLOCK_N"],
bias.element_size(),
)
return (
min(
NUM_SMS,
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
),
)

desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

# pyre-ignore[28]:
_kernel_matmul_fp8_row_tma_persistent[persistent_grid](
_kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
desc_a,
desc_b,
desc_c,
# c,
M,
N,
K,
m_key,
n_key,
k_key,
desc_a_scale,
desc_b_scale,
desc_bias,
a.stride(0),
a.stride(1),
b.stride(0),
Expand All @@ -968,18 +1105,12 @@ def persistent_grid(META):
dot_out_dtype=dot_out_dtype_triton,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
GROUP_M=GROUP_M,
GROUP_M=8,
AB_DTYPE=False,
NUM_SMS=NUM_SMS,
num_stages=num_stages,
num_warps=num_warps,
USE_BIAS=bias is not None,
)
return c.view(output_shape)

if imprecise_acc:
elif imprecise_acc:
_kernel_matmul_fp8_row_imprecise_acc[grid](
a_tl,
b_tl,
Expand Down

0 comments on commit 58bf3de

Please sign in to comment.