add blocked version to address performance issue when N is large #672
Conversation
python/perf-kernels/rmsnorm.py
Outdated
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
    BLOCK_SIZE = blk_size
    # Use blocked approach if BLOCK_SIZE > 65536
    USE_BLOCKED = BLOCK_SIZE > 31743
One thing I noticed is that the layernorm tutorial also uses the dtype to determine whether to use the blocked or non-blocked path. I think what matters is not the actual number of elements, but their total size in bytes.
You are right, line 131 should be:
if n_cols > 65535
done
Actually, I was suggesting something more like this:
https://github.com/ROCm/triton/blob/main_perf/python/perf-kernels/layernorm.py#L118
Revised as suggested.
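For reference, the pattern in the linked layernorm.py presumably looks like the sketch below. This is an illustration, not code from this PR: pick_block_size is a hypothetical helper name, and the 64 KiB budget is taken from the diff further down in this thread.

    import torch
    import triton

    def pick_block_size(x: torch.Tensor, n_cols: int):
        # Cap the block at a 64 KiB per-row budget, expressed in elements
        # of x's dtype, so the threshold scales with element size.
        MAX_FUSED_SIZE = 65536 // x.element_size()
        blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
        use_blocked = n_cols > blk_size  # long rows take the blocked path
        return blk_size, use_blocked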
Can you please add a performance comparison with and without this change?
See above.
python/perf-kernels/rmsnorm.py
Outdated
    mask = col_offsets < n_cols
    tl.assume(input_row_stride >= 0)
    tl.assume(output_row_stride >= 0)
    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
Is this loop needed? NUM_PRGMS = n_rows in the caller, so it will never execute more than once?
Yes, if N is small and M > 304, we will need this persistent loop.
In a persistent kernel, the grid is sized according to the number of CUs. On lines 134 and 135, the grid is set to the number of rows; it is agnostic to the CU count. How is this kernel persistent?
You are correct, this is a fake persistent loop. I will remove it; making the kernel truly persistent will need a new PR.
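For context, a truly persistent launch would look roughly like the sketch below. This is an assumption-laden illustration, not code from this PR; the 304 mentioned above matches the MI300X CU count.

    import torch

    # Size the grid by the number of CUs, not by n_rows; each program then
    # strides over rows inside the kernel, so the tl.range loop over
    # (row_start, n_rows, NUM_PRGMS) does real work when n_rows > NUM_PRGMS.
    NUM_CUS = torch.cuda.get_device_properties(0).multi_processor_count
    NUM_PRGMS = min(n_rows, NUM_CUS)
    grid = (NUM_PRGMS, )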
@@ -153,7 +217,8 @@ def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=dtype)
    y = torch.zeros_like(x, device='cuda')
    n_rows, n_cols = x.shape
    blk_size = triton.next_power_of_2(n_cols)
    MAX_FUSED_SIZE = 65536 // x.element_size()
Can you add some comments on this magic number?
Sorry, just saw your messages. Anything greater than MAX_FUSED_SIZE and we start to see register spills.
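Based on that reply, the requested comment might read something like this (a sketch, paraphrasing the author's explanation above):

    # 65536 bytes (64 KiB) is the per-row footprint beyond which the
    # non-blocked kernel starts to spill registers; dividing by the element
    # size turns that byte budget into a max number of elements per block.
    MAX_FUSED_SIZE = 65536 // x.element_size()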
LGTM after outstanding comments are addressed.
        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
        g_ptrs = g_ptr + cols
        output_ptrs = row_output_ptr + cols
        x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
Have you tried peeling the last iteration? Is it worth trying? Can you add a TODO to try that as part of your next PR?
Loop peeling only applies to the blocked version, not this one. Sure, I will try that.
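For the blocked version, peeling the last iteration might look like the sketch below. It is illustrative only: the kernel name and arguments are hypothetical, and it only shows the sum-of-squares phase of RMSNorm for a single row.

    import triton
    import triton.language as tl

    @triton.jit
    def sum_squares_peeled(input_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
        # Accumulate the sum of squares for one row, with the last (possibly
        # partial) block peeled out so the main loop needs no mask.
        acc = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
        num_full = n_cols // BLOCK_SIZE
        for b in tl.range(0, num_full):
            cols = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
            x = tl.load(input_ptr + cols, cache_modifier=".cg")  # full block: no mask
            acc += x * x
        # Peeled tail: only this load pays for the mask.
        cols = num_full * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = cols < n_cols
        x = tl.load(input_ptr + cols, mask=mask, other=0.0, cache_modifier=".cg")
        acc += x * x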
blocked rmsnorm implementation