diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py index a04408b9cfd5..4c6bbda135b9 100644 --- a/python/perf-kernels/rmsnorm.py +++ b/python/perf-kernels/rmsnorm.py @@ -2,6 +2,7 @@ import torch import sys import pytest +from itertools import product import triton import triton.language as tl @@ -25,15 +26,7 @@ def get_cuda_autotune_config(): def get_hip_autotune_config(): return [ - triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), - triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), - triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), - triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), - triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), - triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), - triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), - triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), - triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), + triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in product([0, 1, 2, 4], [8, 16]) ] @@ -47,12 +40,13 @@ def get_autotune_config(): @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) @triton.jit def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, - BLOCK_SIZE: tl.constexpr): + BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr): row_start = tl.program_id(0) - row_step = tl.num_programs(0) col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols - for row_idx in tl.range(row_start, n_rows, row_step): + tl.assume(input_row_stride >= 0) + tl.assume(output_row_stride >= 0) + for row_idx in tl.range(row_start, n_rows, NUM_PRGMS): row_start_ptr = input_ptr + row_idx * input_row_stride input_ptrs = row_start_ptr + col_offsets input_ptrs = tl.multiple_of(input_ptrs, (16, )) @@ -72,15 +66,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride tl.store(output_ptrs, rms_norm, mask=mask) -def triton_rmsnorm(x, g, epsilon=1e-6): - n_rows, n_cols = x.shape - BLOCK_SIZE = triton.next_power_of_2(n_cols) - - y = torch.empty_like(x, device='cuda') +def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6): + BLOCK_SIZE = blk_size - num_programs = n_rows - grid = lambda meta: (num_programs, ) - rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) + NUM_PRGMS = n_rows + grid = lambda meta: (NUM_PRGMS, ) + rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS) return y @@ -107,8 +98,11 @@ def torch_rmsnorm(x, g): def test_rmsnorm(M, N): torch.manual_seed(0) x = torch.randn(M, N, device='cuda') + y = torch.zeros_like(x, device='cuda') + n_rows, n_cols = x.shape + blk_size = triton.next_power_of_2(n_cols) g = torch.ones((1, N), device='cuda') - y_triton = triton_rmsnorm(x, g) + y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) y_torch = torch_rmsnorm(x, g) @@ -157,13 +151,19 @@ def run_benchmark(args): @triton.testing.perf_report(config) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=dtype) + y = torch.zeros_like(x, device='cuda') + n_rows, n_cols = x.shape + blk_size = triton.next_power_of_2(n_cols) stream = torch.cuda.Stream() torch.cuda.set_stream(stream) g = torch.ones((1, N), device='cuda') if provider == 'torch': ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g)) if provider == 'triton': - ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g)) + ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)) + global verbose + if verbose: + print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})') gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms) @@ -187,17 +187,23 @@ def parse_args(): parser.add_argument('-d', "--dtype", default="fp16") parser.add_argument('-nb', "--no_benchmark", default=False, type=bool) + parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") return parser.parse_args() def main(): args = parse_args() + global verbose if args.no_benchmark: - x = torch.randn(args.M_start, args.N_start) + x = torch.randn(args.M_start, args.N_start, device='cuda') + y = torch.zeros_like(x, device='cuda') + n_rows, n_cols = x.shape + blk_size = triton.next_power_of_2(n_cols) g = torch.ones((1, args.N_start), device='cuda') - triton_rmsnorm(x, g) + triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) else: + verbose = args.v run_benchmark(args)