-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rmsnorm optimization for M = 1 #668
Conversation
for row_idx in tl.range(row_start, n_rows, row_step): | ||
tl.assume(input_row_stride >= 0) | ||
tl.assume(output_row_stride >= 0) | ||
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move the line g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
before the loop, so we do not have to load it for each loop? Or the compiler can do that for us?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure why, but put in the loop gives slightly improved perf in average
| N | Triton (Old) | Triton (New) | Improvement (%) |
|----------|--------------|--------------|-----------------|
| 8192.0 | 2.822 | 3.035 | 7.55 |
| 9216.0 | 3.776 | 4.326 | 14.57 |
| 10240.0 | 4.165 | 4.734 | 13.67 |
| 11264.0 | 4.599 | 5.690 | 23.71 |
| 12288.0 | 5.235 | 5.265 | 0.57 |
| 13312.0 | 5.541 | 5.952 | 7.41 |
| 14336.0 | 6.304 | 5.941 | -5.77 |
| 15360.0 | 7.544 | 7.380 | -2.18 |
| 16384.0 | 7.069 | 7.664 | 8.43 |
| 17408.0 | 7.652 | 8.269 | 8.07 |
| 18432.0 | 8.110 | 8.330 | 2.71 |
| 19456.0 | 8.712 | 9.441 | 8.37 |
| 20480.0 | 8.915 | 9.488 | 6.43 |
| 21504.0 | 10.047 | 10.324 | 2.76 |
| 22528.0 | 9.858 | 10.207 | 3.54 |
| 23552.0 | 10.062 | 9.712 | -3.48 |
| 24576.0 | 11.465 | 10.408 | -9.23 |
| 25600.0 | 10.968 | 10.732 | -2.15 |
| 26624.0 | 12.666 | 10.422 | -17.72 |
| 27648.0 | 11.786 | 12.288 | 4.26 |
| 28672.0 | 13.457 | 11.591 | -13.89 |
| 29696.0 | 12.321 | 13.150 | 6.73 |
| 30720.0 | 12.698 | 13.575 | 6.91 |
| 31744.0 | 14.271 | 15.549 | 8.95 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, really? if that is the case, we can put it in the loop.
python/perf-kernels/rmsnorm.py
Outdated
triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1), | ||
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1), | ||
triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be simplified as if you want:
[triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in itertools.product([0, 1, 2, 4], [4, 8, 16])]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Can you add perf before / after this PR to the description? |
x = torch.randn(args.M_start, args.N_start, device='cuda') | ||
y = torch.zeros_like(x, device='cuda') | ||
n_rows, n_cols = x.shape | ||
blk_size = triton.next_power_of_2(n_cols) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@xiaohuguo2023 Did you notice any big performance drop if blk_size >65k?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, with blk_size>65k, start to have vgpr spills
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my next PR will address this issue
rmsnorm kernel optimization