Passing TMA descriptors through grid constant #3066

Closed · wants to merge 1 commit
297 changes: 214 additions & 83 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
(Lines prefixed with "-" below were removed by this commit; all other lines show the resulting code.)
@@ -680,6 +680,25 @@ def _kernel_matmul_fp8_row_imprecise_acc(
tl.atomic_add(C, acc, mask=mask)


@triton.autotune(
configs=[
Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
num_stages=3,
num_warps=8,
),
],
key=[
"m_key",
"n_key",
"k_key",
],
)
@triton.heuristics(
{
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
}
)
@triton.jit
def _kernel_matmul_fp8_row_tma_persistent(
A_ptr,
@@ -688,8 +707,12 @@ def _kernel_matmul_fp8_row_tma_persistent(
M,
N,
K,
m_key,
n_key,
k_key,
A_scale,
B_scale,
Bias,
stride_am,
stride_ak,
stride_bn,
@@ -704,7 +727,10 @@ def _kernel_matmul_fp8_row_tma_persistent(
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
AB_DTYPE: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_BIAS: tl.constexpr,
) -> None:
"""Matmul kernel of [M, K] @ [N, K] with row-wise scales

@@ -761,6 +787,7 @@ def _kernel_matmul_fp8_row_tma_persistent(

dtype_fp8 = tl.float8e4nv
scale_dtype = tl.float32
bias_dtype = tl.float32

for _ in range(0, k_tiles * tiles_per_SM):
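# (Expository comment, not in the diff) Flattened persistent loop: `ki` cycles
# through the K tiles of the current output tile and wraps to 0 when this
# program advances to its next output tile, so a single `for` covers
# tiles_per_SM * k_tiles steps.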
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
@@ -785,29 +812,128 @@ def _kernel_matmul_fp8_row_tma_persistent(
b = tl._experimental_descriptor_load(
B_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], dtype_fp8
)
- acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

if fp8_fast_accum:
acc = tl.dot(a, b.T, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
else:
acc += tl.dot(a, b.T, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)

if ki == k_tiles - 1:
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M
rn = pid_n * BLOCK_N

# Invert scaling.
a_scale = tl._experimental_descriptor_load(
- A_scale, [rm], [BLOCK_M], scale_dtype
A_scale, [offs_am], [BLOCK_M], scale_dtype
)
b_scale = tl._experimental_descriptor_load(
- B_scale, [rn], [BLOCK_N], scale_dtype
B_scale, [offs_bn], [BLOCK_N], scale_dtype
)
# pyre-ignore[16]: Undefined attribute [16]: `float` has no attribute `__getitem__`.
scale = a_scale[:, None] * b_scale[None, :]
acc *= scale

# Load and add bias if specified.
if USE_BIAS:
bias = tl._experimental_descriptor_load(
Bias, [offs_bn], [BLOCK_N], bias_dtype
)
acc += bias[None, :]

acc = acc.to(C_ptr.dtype.element_ty)

- tl._experimental_descriptor_store(C_ptr, acc, [rm, rn])
tl._experimental_descriptor_store(C_ptr, acc, [offs_am, offs_bn])
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
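
# Host-side sketch (not part of this PR) of the persistent schedule used by
# the kernel above, assuming the standard Triton persistent-matmul pattern:
# it enumerates the (program, tile, k-tile) steps in visit order.
def _persistent_schedule_sketch(num_tiles, k_tiles, num_sms):
    """Yield (program_id, tile_id, ki) in kernel visit order."""
    for pid in range(min(num_sms, num_tiles)):
        # Per-program tile count, matching tiles_per_SM in the kernel.
        tiles_per_sm = num_tiles // num_sms + (1 if pid < num_tiles % num_sms else 0)
        tile_id = pid - num_sms  # the first ki wrap advances this to `pid`
        ki = -1
        for _ in range(k_tiles * tiles_per_sm):
            ki = 0 if ki == k_tiles - 1 else ki + 1
            if ki == 0:
                tile_id += num_sms  # advance to this program's next output tile
            yield pid, tile_id, ki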


# Check whether Triton has grid-constant TMA descriptor support from Triton PR #4498 (https://github.com/triton-lang/triton/pull/4498).
HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl)

if HAS_TMA_DESC:
print(
"TMA benchmarks will be running with experimental grid constant TMA descriptor."
)
else:
print("TMA benchmarks will be running without grid constant TMA descriptor.")


class TmaAutoTuneHelper:

# Duck-typed wrapper that implements the same interface as TmaDescKernelParam from Triton PR #4498.
class KernelParamWrapper:
def __init__(self, desc):
self.desc = desc

def tma_desc_cpu_ptr(self):
return self.desc.data_ptr()

TMA_SIZE = 128

def __init__(self):
self.fill_1d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_1d_tma_descriptor
)
self.fill_2d_tma_descriptor_inner = (
triton.runtime.driver.active.utils.fill_2d_tma_descriptor
)
if HAS_TMA_DESC:
self.descriptors = {}
else:
self.cuda_descriptors = {}

# Call this method outside of the grid-size lambda.
def init_tma_descriptor(self, name):
if HAS_TMA_DESC:
self.descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8
)
else:
self.cuda_descriptors[name] = torch.empty(
TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8
)

# Call this method inside the grid-size lambda.
def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(
ptr, dim, block_dim, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

# Call this method inside the grid-size lambda.
def fill_2d_tma_descriptor(
self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size
):
if HAS_TMA_DESC:
desc_x = self.descriptors[name]
assert desc_x.data_ptr() % 64 == 0
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr()
)
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(
ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr()
)
desc_x.copy_(buf_x, non_blocking=True)

def get_tma_descriptor_kernel_param(self, name):
if HAS_TMA_DESC:
assert self.descriptors[name] is not None
return self.KernelParamWrapper(self.descriptors[name])
else:
assert self.cuda_descriptors[name] is not None
return self.cuda_descriptors[name]
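
# A minimal usage sketch of TmaAutoTuneHelper (illustrative; `kernel`, its
# argument list, and the shapes are hypothetical, not from this diff):
def _tma_helper_usage_example(kernel, a, M, K, N, NUM_SMS):
    helper = TmaAutoTuneHelper()
    helper.init_tma_descriptor("a")  # allocate the 128-byte buffer once

    def grid(META):
        # Refill inside the grid callable so each autotune config gets a
        # descriptor matching its BLOCK_M x BLOCK_K tiling.
        helper.fill_2d_tma_descriptor(
            "a", a.data_ptr(), M, K, META["BLOCK_M"], META["BLOCK_K"], a.element_size()
        )
        # Persistent launch: e.g. M = N = 4096 with BLOCK_M = 128, BLOCK_N = 256
        # gives 32 * 16 = 512 tiles, so the grid is min(NUM_SMS, 512) programs.
        return (
            min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),
        )

    desc_a = helper.get_tma_descriptor_kernel_param("a")
    kernel[grid](desc_a, M, N, K)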


@torch.library.custom_op("triton::matmul_fp8_row", mutates_args=())
def matmul_fp8_row(
a: torch.Tensor,
Expand Down Expand Up @@ -880,85 +1006,96 @@ def persistent_grid(META):
)

if tma_persistent:
- if bias is not None:
- raise NotImplementedError("TMA persistent kernel doesn't support bias yet")

- # used by TMA persistent kernel
- TMA_SIZE = 128

- # autotune doesn't work with TMA
- # https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312

- BLOCK_M = 128
- BLOCK_N = 256
- BLOCK_K = 128
- GROUP_M = 8
- num_stages = 3
- num_warps = 8

- desc_a = torch.empty(TMA_SIZE, dtype=torch.int8)
- desc_b = torch.empty(TMA_SIZE, dtype=torch.int8)
- desc_c = torch.empty(TMA_SIZE, dtype=torch.int8)
- desc_a_scale = torch.empty(TMA_SIZE, dtype=torch.int8)
- desc_b_scale = torch.empty(TMA_SIZE, dtype=torch.int8)

- triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
- a_tl.data_ptr(),
- M,
- K,
- BLOCK_M,
- BLOCK_K,
- a_tl.element_size(),
- desc_a.data_ptr(),
- )
- triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
- b_tl.data_ptr(),
- N,
- K,
- BLOCK_N,
- BLOCK_K,
- b_tl.element_size(),
- desc_b.data_ptr(),
- )
- triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
- c.data_ptr(),
- M,
- N,
- BLOCK_M,
- BLOCK_N,
- c.element_size(),
- desc_c.data_ptr(),
- )
- triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
- a_scale.data_ptr(),
- M,
- BLOCK_M,
- a_scale.element_size(),
- desc_a_scale.data_ptr(),
- )
- triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
- b_scale.data_ptr(),
- N,
- BLOCK_N,
- b_scale.element_size(),
- desc_b_scale.data_ptr(),
- )
- desc_a = torch.tensor(desc_a, device="cuda")
- desc_b = torch.tensor(desc_b, device="cuda")
- desc_c = torch.tensor(desc_c, device="cuda")
- desc_a_scale = torch.tensor(desc_a_scale, device="cuda")
- desc_b_scale = torch.tensor(desc_b_scale, device="cuda")
desc_helper = TmaAutoTuneHelper()
desc_helper.init_tma_descriptor("a")
desc_helper.init_tma_descriptor("b")
desc_helper.init_tma_descriptor("c")
desc_helper.init_tma_descriptor("a_scale")
desc_helper.init_tma_descriptor("b_scale")
desc_helper.init_tma_descriptor("bias")

def persistent_grid_tma(META):
nonlocal desc_helper
desc_helper.fill_2d_tma_descriptor(
"a",
a_tl.data_ptr(),
M,
K,
META["BLOCK_M"],
META["BLOCK_K"],
a_tl.element_size(),
)

desc_helper.fill_2d_tma_descriptor(
"b",
b_tl.data_ptr(),
N,
K,
META["BLOCK_N"],
META["BLOCK_K"],
b_tl.element_size(),
)
desc_helper.fill_2d_tma_descriptor(
"c",
c.data_ptr(),
M,
N,
META["BLOCK_M"],
META["BLOCK_N"],
c.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"a_scale",
a_scale.data_ptr(),
M,
META["BLOCK_M"],
a_scale.element_size(),
)
desc_helper.fill_1d_tma_descriptor(
"b_scale",
b_scale.data_ptr(),
N,
META["BLOCK_N"],
b_scale.element_size(),
)
if bias is not None:
desc_helper.fill_1d_tma_descriptor(
"bias",
bias.data_ptr(),
N,
META["BLOCK_N"],
bias.element_size(),
)
return (
min(
NUM_SMS,
triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
),
)

desc_a = desc_helper.get_tma_descriptor_kernel_param("a")
desc_b = desc_helper.get_tma_descriptor_kernel_param("b")
desc_c = desc_helper.get_tma_descriptor_kernel_param("c")
desc_a_scale = desc_helper.get_tma_descriptor_kernel_param("a_scale")
desc_b_scale = desc_helper.get_tma_descriptor_kernel_param("b_scale")
desc_bias = desc_helper.get_tma_descriptor_kernel_param("bias")

# pyre-ignore[28]:
- _kernel_matmul_fp8_row_tma_persistent[persistent_grid](
_kernel_matmul_fp8_row_tma_persistent[persistent_grid_tma](
desc_a,
desc_b,
desc_c,
# c,
M,
N,
K,
m_key,
n_key,
k_key,
desc_a_scale,
desc_b_scale,
desc_bias,
a.stride(0),
a.stride(1),
b.stride(0),
@@ -968,18 +1105,12 @@ def persistent_grid(META):
dot_out_dtype=dot_out_dtype_triton,
allow_tf32=allow_tf32,
fp8_fast_accum=fp8_fast_accum,
- BLOCK_M=BLOCK_M,
- BLOCK_N=BLOCK_N,
- BLOCK_K=BLOCK_K,
- GROUP_M=GROUP_M,
GROUP_M=8,
AB_DTYPE=False,
NUM_SMS=NUM_SMS,
- num_stages=num_stages,
- num_warps=num_warps,
USE_BIAS=bias is not None,
)
return c.view(output_shape)

- if imprecise_acc:
elif imprecise_acc:
_kernel_matmul_fp8_row_imprecise_acc[grid](
a_tl,
b_tl,