rmsnorm optimization for M = 1 (#668)
* rmsnorm opt for M=1

* forgot to update the pytest

* fix --no_benchmark option

* tidy autotune configs and add a verbose option to print the best autotune config
xiaohuguo2023 authored Dec 2, 2024
1 parent 6e7ad94 commit fc558e7
Showing 1 changed file with 30 additions and 24 deletions.
python/perf-kernels/rmsnorm.py (30 additions, 24 deletions)
@@ -2,6 +2,7 @@
 import torch
 import sys
 import pytest
+from itertools import product
 
 import triton
 import triton.language as tl
@@ -25,15 +26,7 @@ def get_cuda_autotune_config():
 
 def get_hip_autotune_config():
     return [
-        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
+        triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in product([0, 1, 2, 4], [8, 16])
     ]


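For reference, the new comprehension is a denser sweep than the nine hand-written configs it replaces. A standalone sketch of what product() enumerates (the name pairs is illustrative):

# What the list comprehension above expands to: product([0, 1, 2, 4], [8, 16])
# yields eight (waves_per_eu, num_warps) pairs, each built into a
# triton.Config with num_stages=2 (waves_per_eu=0 presumably leaves the
# occupancy hint to the compiler).
from itertools import product

pairs = list(product([0, 1, 2, 4], [8, 16]))
print(pairs)
# [(0, 8), (0, 16), (1, 8), (1, 16), (2, 8), (2, 16), (4, 8), (4, 16)]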
@@ -47,12 +40,13 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
 def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
-               BLOCK_SIZE: tl.constexpr):
+               BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
     row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
+    tl.assume(input_row_stride >= 0)
+    tl.assume(output_row_stride >= 0)
+    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
         row_start_ptr = input_ptr + row_idx * input_row_stride
         input_ptrs = row_start_ptr + col_offsets
         input_ptrs = tl.multiple_of(input_ptrs, (16, ))
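Two kernel changes here: the grid-stride step is now the compile-time constant NUM_PRGMS rather than the runtime value tl.num_programs(0), and tl.assume tells the compiler the row strides are non-negative, which can simplify address arithmetic. A pure-Python model of the resulting row schedule (rows_for_program is a hypothetical helper, not part of the kernel):

# Pure-Python model of the persistent-kernel row schedule above.
# With NUM_PRGMS == n_rows every program body runs exactly once, so for
# M = 1 a single workgroup handles the one row with no loop overhead.
def rows_for_program(pid, n_rows, num_prgms):
    return list(range(pid, n_rows, num_prgms))

print(rows_for_program(0, 1, 1))   # M = 1: program 0 -> [0]
print(rows_for_program(2, 10, 4))  # general case: program 2 -> [2, 6]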
@@ -72,15 +66,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
         tl.store(output_ptrs, rms_norm, mask=mask)
 
 
-def triton_rmsnorm(x, g, epsilon=1e-6):
-    n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
-
-    y = torch.empty_like(x, device='cuda')
+def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
+    BLOCK_SIZE = blk_size
 
-    num_programs = n_rows
-    grid = lambda meta: (num_programs, )
-    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
+    NUM_PRGMS = n_rows
+    grid = lambda meta: (NUM_PRGMS, )
+    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS)
 
     return y
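The wrapper now takes the output buffer and sizes from the caller, keeping allocation and next_power_of_2 out of anything the benchmark times. Because NUM_PRGMS is a tl.constexpr, each distinct row count should also compile its own specialized kernel. A minimal sketch of the launch shape (launch_shape is illustrative):

# Sketch of the launch arithmetic above: one program per row.
def launch_shape(n_rows):
    NUM_PRGMS = n_rows
    grid = lambda meta: (NUM_PRGMS, )
    return grid({})

print(launch_shape(1))    # (1,)  -> the M = 1 fast path: a single workgroup
print(launch_shape(128))  # (128,)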
@@ -107,8 +98,11 @@ def torch_rmsnorm(x, g):
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
+    y = torch.zeros_like(x, device='cuda')
+    n_rows, n_cols = x.shape
+    blk_size = triton.next_power_of_2(n_cols)
     g = torch.ones((1, N), device='cuda')
-    y_triton = triton_rmsnorm(x, g)
+    y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
 
     y_torch = torch_rmsnorm(x, g)
 
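torch_rmsnorm itself is untouched by this commit and its body falls outside the diff; for orientation, a hedged sketch of the standard RMSNorm formula the test presumably compares against:

# Hedged reference sketch (the actual torch_rmsnorm body is not shown in
# this diff): y = x / sqrt(mean(x^2, dim=-1) + eps) * g
import torch

def rmsnorm_reference(x, g, eps=1e-6):
    rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps)
    return x / rms * g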
@@ -157,13 +151,19 @@ def run_benchmark(args):
     @triton.testing.perf_report(config)
     def benchmark(M, N, provider):
         x = torch.randn(M, N, device='cuda', dtype=dtype)
+        y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)
         stream = torch.cuda.Stream()
         torch.cuda.set_stream(stream)
         g = torch.ones((1, N), device='cuda')
         if provider == 'torch':
             ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
         if provider == 'triton':
-            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g))
+            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size))
+        global verbose
+        if verbose:
+            print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
         gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
         return gbps(ms)
 
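The gbps lambda counts one read of x and one write of y, hence the factor of 2, and rms_kernel.best_config reports the config the autotuner settled on. A worked instance of the formula (the gbps helper below is a standalone restatement):

# Worked example of the gbps() formula: 2 * numel * element_size bytes
# moved (read x, write y), converted to GB and divided by seconds.
def gbps(ms, numel, elem_size):
    return 2 * numel * elem_size * 1e-9 / (ms * 1e-3)

# M = 1, N = 8192 in fp16 (2 bytes): 32 KiB moved; 0.01 ms -> ~3.28 GB/s.
print(round(gbps(0.01, 8192, 2), 2))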
@@ -187,17 +187,23 @@ def parse_args():
 
     parser.add_argument('-d', "--dtype", default="fp16")
     parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)
+    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
 
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
+    global verbose
     if args.no_benchmark:
-        x = torch.randn(args.M_start, args.N_start)
+        x = torch.randn(args.M_start, args.N_start, device='cuda')
+        y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)
         g = torch.ones((1, args.N_start), device='cuda')
-        triton_rmsnorm(x, g)
+        triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
     else:
+        verbose = args.v
         run_benchmark(args)
 
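An aside the commit does not change: argparse's type=bool converts any non-empty string to True, so --no_benchmark False still enables the branch, whereas the store_true pattern used for -v behaves as expected. A minimal demonstration:

# Demonstration of the argparse type=bool pitfall vs. store_true.
import argparse

p = argparse.ArgumentParser()
p.add_argument('-nb', "--no_benchmark", default=False, type=bool)
p.add_argument("-v", action='store_true', default=False)
print(p.parse_args(["--no_benchmark", "False"]).no_benchmark)  # True (!)
print(p.parse_args([]).v)                                      # False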