-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rmsnorm optimization for M = 1 #668
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import torch | ||
import sys | ||
import pytest | ||
from itertools import product | ||
|
||
import triton | ||
import triton.language as tl | ||
|
@@ -25,15 +26,7 @@ def get_cuda_autotune_config(): | |
|
||
def get_hip_autotune_config(): | ||
return [ | ||
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in product([0, 1, 2, 4], [8, 16]) | ||
] | ||
|
||
|
||
|
@@ -47,12 +40,13 @@ def get_autotune_config(): | |
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) | ||
@triton.jit | ||
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, | ||
BLOCK_SIZE: tl.constexpr): | ||
BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr): | ||
row_start = tl.program_id(0) | ||
row_step = tl.num_programs(0) | ||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
mask = col_offsets < n_cols | ||
for row_idx in tl.range(row_start, n_rows, row_step): | ||
tl.assume(input_row_stride >= 0) | ||
tl.assume(output_row_stride >= 0) | ||
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS): | ||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
input_ptrs = row_start_ptr + col_offsets | ||
input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
|
@@ -72,15 +66,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride | |
tl.store(output_ptrs, rms_norm, mask=mask) | ||
|
||
|
||
def triton_rmsnorm(x, g, epsilon=1e-6): | ||
n_rows, n_cols = x.shape | ||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
|
||
y = torch.empty_like(x, device='cuda') | ||
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6): | ||
BLOCK_SIZE = blk_size | ||
|
||
num_programs = n_rows | ||
grid = lambda meta: (num_programs, ) | ||
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) | ||
NUM_PRGMS = n_rows | ||
grid = lambda meta: (NUM_PRGMS, ) | ||
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS) | ||
|
||
return y | ||
|
||
|
@@ -107,8 +98,11 @@ def torch_rmsnorm(x, g): | |
def test_rmsnorm(M, N): | ||
torch.manual_seed(0) | ||
x = torch.randn(M, N, device='cuda') | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) | ||
g = torch.ones((1, N), device='cuda') | ||
y_triton = triton_rmsnorm(x, g) | ||
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) | ||
|
||
y_torch = torch_rmsnorm(x, g) | ||
|
||
|
@@ -157,13 +151,19 @@ def run_benchmark(args): | |
@triton.testing.perf_report(config) | ||
def benchmark(M, N, provider): | ||
x = torch.randn(M, N, device='cuda', dtype=dtype) | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) | ||
stream = torch.cuda.Stream() | ||
torch.cuda.set_stream(stream) | ||
g = torch.ones((1, N), device='cuda') | ||
if provider == 'torch': | ||
ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g)) | ||
if provider == 'triton': | ||
ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g)) | ||
ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)) | ||
global verbose | ||
if verbose: | ||
print(f'SIZE: {N} Best tuning config: ({rms_kernel.best_config})') | ||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) | ||
return gbps(ms) | ||
|
||
|
@@ -187,17 +187,23 @@ def parse_args(): | |
|
||
parser.add_argument('-d', "--dtype", default="fp16") | ||
parser.add_argument('-nb', "--no_benchmark", default=False, type=bool) | ||
parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config") | ||
|
||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
global verbose | ||
if args.no_benchmark: | ||
x = torch.randn(args.M_start, args.N_start) | ||
x = torch.randn(args.M_start, args.N_start, device='cuda') | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xiaohuguo2023 Did you notice any big performance drop if blk_size >65k? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, with blk_size>65k, start to have vgpr spills There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my next PR will address this issue |
||
g = torch.ones((1, args.N_start), device='cuda') | ||
triton_rmsnorm(x, g) | ||
triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) | ||
else: | ||
verbose = args.v | ||
run_benchmark(args) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move the line
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
before the loop, so we do not have to load it for each loop? Or the compiler can do that for us?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure why, but put in the loop gives slightly improved perf in average
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, really? if that is the case, we can put it in the loop.