
rmsnorm optimization for M = 1 #668

Merged · 4 commits into main_perf on Dec 2, 2024
Conversation

xiaohuguo2023 (Member) commented on Nov 27, 2024:

rmsnorm kernel optimization:

  • enable buffer load/store
  • change grid size to a tl.constexpr
  • add autotuning configs with waves_per_eu = 0
  • move memory allocation outside of the wrapper to reduce autotuning overhead
  • fix an issue in the no_benchmark case
  • on average, an 88% performance improvement over the base version (see the sketch below)
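For context, a minimal sketch of the persistent-kernel structure these bullets describe. It is illustrative rather than the merged code: the kernel name, signature, and eps handling are assumptions, while NUM_PRGMS, BLOCK_SIZE, and the tl.assume hints follow the snippets quoted in the review below.

```python
import triton
import triton.language as tl

@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride,
               output_row_stride, n_rows, n_cols, eps,
               BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
    row_start = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Non-negative stride hints help the compiler emit buffer load/store ops.
    tl.assume(input_row_stride >= 0)
    tl.assume(output_row_stride >= 0)
    # Persistent loop: NUM_PRGMS is a tl.constexpr, so the grid size is a
    # compile-time constant and each program strides over the rows.
    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
        row = tl.load(input_ptr + row_idx * input_row_stride + col_offsets,
                      mask=mask, other=0.0).to(tl.float32)
        g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
        rms = tl.sqrt(tl.sum(row * row, axis=0) / n_cols + eps)
        tl.store(output_ptr + row_idx * output_row_stride + col_offsets,
                 row / rms * g, mask=mask)
```

With the input, output, and g buffers allocated up front (the "memory allocation outside of the wrapper" bullet), the launch is a 1-D grid of NUM_PRGMS programs, so autotuning reruns only the kernel, not the allocations.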

-    for row_idx in tl.range(row_start, n_rows, row_step):
+    tl.assume(input_row_stride >= 0)
+    tl.assume(output_row_stride >= 0)
+    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
Can we move the line g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0) before the loop, so we do not have to load it on each iteration? Or can the compiler do that for us?
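For reference, the hoisting being asked about would look roughly like this (a sketch only; the surrounding kernel is as in the snippet above):

```python
# Hypothetical hoisted variant: g depends only on col_offsets, not on
# row_idx, so in principle it can be loaded once before the loop.
g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
    # ... load the row, compute rms, scale by the pre-loaded g ...
    pass
```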

xiaohuguo2023 (Member, Author) replied on Dec 2, 2024:

Not sure why, but keeping it in the loop gives slightly better performance on average:

|      N   | Triton (Old) | Triton (New) | Improvement (%) |
|----------|--------------|--------------|-----------------|
|   8192.0 |      2.822   |     3.035    |      7.55       |
|   9216.0 |      3.776   |     4.326    |     14.57       |
|  10240.0 |      4.165   |     4.734    |     13.67       |
|  11264.0 |      4.599   |     5.690    |     23.71       |
|  12288.0 |      5.235   |     5.265    |      0.57       |
|  13312.0 |      5.541   |     5.952    |      7.41       |
|  14336.0 |      6.304   |     5.941    |     -5.77       |
|  15360.0 |      7.544   |     7.380    |     -2.18       |
|  16384.0 |      7.069   |     7.664    |      8.43       |
|  17408.0 |      7.652   |     8.269    |      8.07       |
|  18432.0 |      8.110   |     8.330    |      2.71       |
|  19456.0 |      8.712   |     9.441    |      8.37       |
|  20480.0 |      8.915   |     9.488    |      6.43       |
|  21504.0 |     10.047   |    10.324    |      2.76       |
|  22528.0 |      9.858   |    10.207    |      3.54       |
|  23552.0 |     10.062   |     9.712    |     -3.48       |
|  24576.0 |     11.465   |    10.408    |     -9.23       |
|  25600.0 |     10.968   |    10.732    |     -2.15       |
|  26624.0 |     12.666   |    10.422    |    -17.72       |
|  27648.0 |     11.786   |    12.288    |      4.26       |
|  28672.0 |     13.457   |    11.591    |    -13.89       |
|  29696.0 |     12.321   |    13.150    |      6.73       |
|  30720.0 |     12.698   |    13.575    |      6.91       |
|  31744.0 |     14.271   |    15.549    |      8.95       |


Oh, really? If that is the case, we can keep it in the loop.

triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2),

I think this can be simplified, if you want, to:
[triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in itertools.product([0, 1, 2, 4], [4, 8, 16])]
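Spelled out with its import (the get_autotune_config wrapper name here is illustrative), that expands to:

```python
import itertools
import triton

# Hypothetical wrapper name; the comprehension is the suggestion above,
# generating 4 x 3 = 12 configs over waves_per_eu and num_warps.
def get_autotune_config():
    return [
        triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2)
        for (we, nw) in itertools.product([0, 1, 2, 4], [4, 8, 16])
    ]
```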

xiaohuguo2023 (Member, Author):

done

vgokhale (Collaborator) commented on Dec 2, 2024:

Can you add perf before / after this PR to the description?

x = torch.randn(args.M_start, args.N_start, device='cuda')
y = torch.zeros_like(x, device='cuda')
n_rows, n_cols = x.shape
blk_size = triton.next_power_of_2(n_cols)


@xiaohuguo2023 Did you notice any big performance drop if blk_size > 65k?

xiaohuguo2023 (Member, Author):

Yes, with blk_size > 65k we start to see VGPR spills.


My next PR will address this issue.

vgokhale merged commit fc558e7 into main_perf on Dec 2, 2024
4 checks passed