From a586bd10a6b26ea83cdaa2fa0b1f027d165fbc1c Mon Sep 17 00:00:00 2001
From: Xiaohu Guo
Date: Wed, 27 Nov 2024 10:24:41 -0600
Subject: [PATCH 1/4] rmsnorm opt for M=1

---
 python/perf-kernels/rmsnorm.py | 54 +++++++++++++++++++---------------
 1 file changed, 30 insertions(+), 24 deletions(-)

diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py
index a04408b9cfd5..198a506380ac 100644
--- a/python/perf-kernels/rmsnorm.py
+++ b/python/perf-kernels/rmsnorm.py
@@ -25,15 +25,18 @@ def get_cuda_autotune_config():

 def get_hip_autotune_config():
     return [
-        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
+        triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 0}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 0}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2),
     ]


@@ -47,12 +50,13 @@ def get_autotune_config():
 @triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
 @triton.jit
 def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
-               BLOCK_SIZE: tl.constexpr):
+               BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
     row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
+    tl.assume(input_row_stride >= 0)
+    tl.assume(output_row_stride >= 0)
+    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
         row_start_ptr = input_ptr + row_idx * input_row_stride
         input_ptrs = row_start_ptr + col_offsets
         input_ptrs = tl.multiple_of(input_ptrs, (16, ))
@@ -72,15 +76,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
         tl.store(output_ptrs, rms_norm, mask=mask)


-def triton_rmsnorm(x, g, epsilon=1e-6):
-    n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
+def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
+    BLOCK_SIZE = blk_size

-    y = torch.empty_like(x, device='cuda')
-
-    num_programs = n_rows
-    grid = lambda meta: (num_programs, )
-    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
+    NUM_PRGMS = n_rows
+    grid = lambda meta: (NUM_PRGMS, )
+    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS)

     return y

@@ -107,8 +108,9 @@ def torch_rmsnorm(x, g):
 def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
+    y = torch.zeros_like(x, device='cuda')
     g = torch.ones((1, N), device='cuda')
-    y_triton = triton_rmsnorm(x, g)
+    y_triton = triton_rmsnorm(x, y, g)

     y_torch = torch_rmsnorm(x, g)

@@ -157,13 +159,16 @@ def run_benchmark(args):
     @triton.testing.perf_report(config)
     def benchmark(M, N, provider):
         x = torch.randn(M, N, device='cuda', dtype=dtype)
+        y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)
         stream = torch.cuda.Stream()
         torch.cuda.set_stream(stream)
         g = torch.ones((1, N), device='cuda')
         if provider == 'torch':
             ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
         if provider == 'triton':
-            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g))
+            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size))
         gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
         return gbps(ms)

@@ -194,9 +199,10 @@ def parse_args():

 def main():
     args = parse_args()
     if args.no_benchmark:
-        x = torch.randn(args.M_start, args.N_start)
+        x = torch.randn(args.M_start, args.N_start, device='cuda')
+        y = torch.zeros_like(x, device='cuda')
         g = torch.ones((1, args.N_start), device='cuda')
-        triton_rmsnorm(x, g)
+        triton_rmsnorm(x, y, g)
     else:
         run_benchmark(args)
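Note on PATCH 1: with the grid fixed at NUM_PRGMS = n_rows, each program normalizes exactly one row, and passing NUM_PRGMS as a tl.constexpr (instead of reading tl.num_programs(0) at run time) gives the compiler a compile-time trip count for the row loop; the tl.assume calls likewise guarantee non-negative strides, which can enable simpler address arithmetic. The per-row math is standard RMSNorm. A minimal PyTorch sketch of those semantics, for comparison only (the epsilon placement follows the common convention; the file's own torch_rmsnorm is the ground truth and is not shown in this diff):

    import torch

    def rmsnorm_reference(x: torch.Tensor, g: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
        # Per-row root mean square over the feature dimension, then scale by g:
        # y[i, :] = x[i, :] / sqrt(mean(x[i, :]**2) + epsilon) * g
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
        return x / rms * g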
From 4589d5be1d1a47b5eeaafe26f30cc11403745d45 Mon Sep 17 00:00:00 2001
From: Xiaohu Guo
Date: Wed, 27 Nov 2024 11:24:07 -0600
Subject: [PATCH 2/4] forgot to update the pytest

---
 python/perf-kernels/rmsnorm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py
index 198a506380ac..45ec631e3728 100644
--- a/python/perf-kernels/rmsnorm.py
+++ b/python/perf-kernels/rmsnorm.py
@@ -109,8 +109,10 @@ def test_rmsnorm(M, N):
     torch.manual_seed(0)
     x = torch.randn(M, N, device='cuda')
     y = torch.zeros_like(x, device='cuda')
+    n_rows, n_cols = x.shape
+    blk_size = triton.next_power_of_2(n_cols)
     g = torch.ones((1, N), device='cuda')
-    y_triton = triton_rmsnorm(x, y, g)
+    y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)

     y_torch = torch_rmsnorm(x, g)

From bf4d4bb8b75d8b7df8c247bc0ce2e756972c1edf Mon Sep 17 00:00:00 2001
From: Xiaohu Guo
Date: Mon, 2 Dec 2024 08:57:10 -0600
Subject: [PATCH 3/4] fix --no_benchmark option

---
 python/perf-kernels/rmsnorm.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py
index 45ec631e3728..e313ab162675 100644
--- a/python/perf-kernels/rmsnorm.py
+++ b/python/perf-kernels/rmsnorm.py
@@ -203,8 +203,10 @@ def main():
     if args.no_benchmark:
         x = torch.randn(args.M_start, args.N_start, device='cuda')
         y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)
         g = torch.ones((1, args.N_start), device='cuda')
-        triton_rmsnorm(x, y, g)
+        triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
     else:
         run_benchmark(args)
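Note on PATCHES 1-3: after this series, triton_rmsnorm no longer allocates its output or derives the block size internally; every caller pre-allocates y and passes n_rows, n_cols, and a power-of-two block size explicitly, keeping the M=1 hot path free of per-call allocation and shape logic. A minimal driver in the same pattern as the updated test (the rmsnorm import path and the assert tolerances are assumptions, not part of the patches):

    import torch
    import triton
    from rmsnorm import triton_rmsnorm, torch_rmsnorm  # assumed import path

    torch.manual_seed(0)
    M, N = 1, 8192                            # the M=1 case this series targets
    x = torch.randn(M, N, device='cuda')
    y = torch.zeros_like(x, device='cuda')    # caller-owned output buffer
    g = torch.ones((1, N), device='cuda')
    n_rows, n_cols = x.shape
    blk_size = triton.next_power_of_2(n_cols)

    y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
    # Sanity check against the eager reference; tolerances are illustrative.
    torch.testing.assert_close(y_triton, torch_rmsnorm(x, g), atol=1e-5, rtol=1e-5)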
From d114cf9cece632afd2fbcb6ed2596a3bcb7465f2 Mon Sep 17 00:00:00 2001
From: Xiaohu Guo
Date: Mon, 2 Dec 2024 09:59:42 -0600
Subject: [PATCH 4/4] tidy autotune configs and add a verbose option to print
 the best autotune config

---
 python/perf-kernels/rmsnorm.py | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/python/perf-kernels/rmsnorm.py b/python/perf-kernels/rmsnorm.py
index e313ab162675..4c6bbda135b9 100644
--- a/python/perf-kernels/rmsnorm.py
+++ b/python/perf-kernels/rmsnorm.py
@@ -2,6 +2,7 @@
 import torch
 import sys
 import pytest
+from itertools import product

 import triton
 import triton.language as tl
@@ -25,18 +26,7 @@ def get_cuda_autotune_config():

 def get_hip_autotune_config():
     return [
-        triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2),
-        triton.Config({'waves_per_eu': 0}, num_warps=8, num_stages=2),
-        triton.Config({'waves_per_eu': 0}, num_warps=16, num_stages=2),
-        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2),
-        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2),
-        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2),
-        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2),
-        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2),
-        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2),
-        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2),
-        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2),
-        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in product([0, 1, 2, 4], [8, 16])
     ]

@@ -171,6 +161,9 @@ def benchmark(M, N, provider):
             ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
         if provider == 'triton':
             ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size))
+        global verbose
+        if verbose:
+            print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})')
         gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
         return gbps(ms)

@@ -194,12 +187,14 @@ def parse_args():

     parser.add_argument('-d', "--dtype", default="fp16")
     parser.add_argument('-nb', "--no_benchmark", default=False, type=bool)
+    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")

     return parser.parse_args()


 def main():
     args = parse_args()
+    global verbose
     if args.no_benchmark:
         x = torch.randn(args.M_start, args.N_start, device='cuda')
         y = torch.zeros_like(x, device='cuda')
@@ -208,6 +203,7 @@ def main():
         g = torch.ones((1, args.N_start), device='cuda')
         triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
     else:
+        verbose = args.v
         run_benchmark(args)
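Note on PATCH 4: the itertools.product comprehension replaces the twelve hand-written HIP configs with eight generated ones, sweeping waves_per_eu over {0, 1, 2, 4} and num_warps over {8, 16}; the num_warps=4 variants from the earlier list are dropped. A quick way to inspect the expansion:

    from itertools import product

    # Mirrors the comprehension in get_hip_autotune_config(): eight configs,
    # all with num_stages=2.
    for we, nw in product([0, 1, 2, 4], [8, 16]):
        print(f"waves_per_eu={we}, num_warps={nw}, num_stages=2")

With the new flag, running the benchmark as `python python/perf-kernels/rmsnorm.py -v` should print the autotuner's chosen config per size via rms_kernel.best_config (assuming the file's usual __main__ entry point calls main(), which is outside this diff).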