-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rmsnorm optimization for M = 1 #668
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,15 +25,18 @@ def get_cuda_autotune_config(): | |
|
||
def get_hip_autotune_config(): | ||
return [ | ||
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2), | ||
triton.Config({'waves_per_eu': 0}, num_warps=8, num_stages=2), | ||
triton.Config({'waves_per_eu': 0}, num_warps=16, num_stages=2), | ||
triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2), | ||
triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2), | ||
triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2), | ||
triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2), | ||
triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2), | ||
triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2), | ||
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2), | ||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2), | ||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2), | ||
] | ||
|
||
|
||
|
@@ -47,12 +50,13 @@ def get_autotune_config(): | |
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True) | ||
@triton.jit | ||
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon, | ||
BLOCK_SIZE: tl.constexpr): | ||
BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr): | ||
row_start = tl.program_id(0) | ||
row_step = tl.num_programs(0) | ||
col_offsets = tl.arange(0, BLOCK_SIZE) | ||
mask = col_offsets < n_cols | ||
for row_idx in tl.range(row_start, n_rows, row_step): | ||
tl.assume(input_row_stride >= 0) | ||
tl.assume(output_row_stride >= 0) | ||
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we move the line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure why, but put in the loop gives slightly improved perf in average
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, really? if that is the case, we can put it in the loop. |
||
row_start_ptr = input_ptr + row_idx * input_row_stride | ||
input_ptrs = row_start_ptr + col_offsets | ||
input_ptrs = tl.multiple_of(input_ptrs, (16, )) | ||
|
@@ -72,15 +76,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride | |
tl.store(output_ptrs, rms_norm, mask=mask) | ||
|
||
|
||
def triton_rmsnorm(x, g, epsilon=1e-6): | ||
n_rows, n_cols = x.shape | ||
BLOCK_SIZE = triton.next_power_of_2(n_cols) | ||
|
||
y = torch.empty_like(x, device='cuda') | ||
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6): | ||
BLOCK_SIZE = blk_size | ||
|
||
num_programs = n_rows | ||
grid = lambda meta: (num_programs, ) | ||
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE) | ||
NUM_PRGMS = n_rows | ||
grid = lambda meta: (NUM_PRGMS, ) | ||
rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS) | ||
|
||
return y | ||
|
||
|
@@ -107,8 +108,11 @@ def torch_rmsnorm(x, g): | |
def test_rmsnorm(M, N): | ||
torch.manual_seed(0) | ||
x = torch.randn(M, N, device='cuda') | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) | ||
g = torch.ones((1, N), device='cuda') | ||
y_triton = triton_rmsnorm(x, g) | ||
y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size) | ||
|
||
y_torch = torch_rmsnorm(x, g) | ||
|
||
|
@@ -157,13 +161,16 @@ def run_benchmark(args): | |
@triton.testing.perf_report(config) | ||
def benchmark(M, N, provider): | ||
x = torch.randn(M, N, device='cuda', dtype=dtype) | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) | ||
stream = torch.cuda.Stream() | ||
torch.cuda.set_stream(stream) | ||
g = torch.ones((1, N), device='cuda') | ||
if provider == 'torch': | ||
ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g)) | ||
if provider == 'triton': | ||
ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g)) | ||
ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)) | ||
gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) | ||
return gbps(ms) | ||
|
||
|
@@ -194,9 +201,10 @@ def parse_args(): | |
def main(): | ||
args = parse_args() | ||
if args.no_benchmark: | ||
x = torch.randn(args.M_start, args.N_start) | ||
x = torch.randn(args.M_start, args.N_start, device='cuda') | ||
y = torch.zeros_like(x, device='cuda') | ||
g = torch.ones((1, args.N_start), device='cuda') | ||
triton_rmsnorm(x, g) | ||
triton_rmsnorm(x, y, g) | ||
else: | ||
run_benchmark(args) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be simplified as if you want:
[triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in itertools.product([0, 1, 2, 4], [4, 8, 16])]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done