Add blocked version to address performance issue when N is large #672

Merged

6 commits merged into main_perf on Dec 6, 2024

Conversation

xiaohuguo2023 (Member) commented Dec 4, 2024

Blocked rmsnorm implementation:

  • benefits when N > 30K; the best BLOCK_SIZE is 65536
  • adds a unit test for N > 30K
  • adds logic in the wrapper to set the best BLOCK_SIZE (a sketch follows the table below)
  • for the non-blocked version, this PR has no performance impact
  • non-blocked vs. blocked performance comparison:
| N | Triton (non-blocked) | Triton (blocked) | Improvement (%) |
|---:|---:|---:|---:|
| 31744 | 11.736 | 11.615 | -1.03 |
| 41984 | 7.326 | 11.911 | 62.59 |
| 52224 | 8.381 | 13.114 | 56.42 |
| 62464 | 9.721 | 15.099 | 55.30 |
| 72704 | 1.090 | 14.699 | 1249.45 |
| 82944 | 1.272 | 15.360 | 1107.08 |
| 93184 | 1.431 | 16.568 | 1058.33 |
| 103424 | 1.614 | 16.212 | 904.21 |
| 113664 | 1.766 | 18.606 | 953.76 |
| 123904 | 1.896 | 18.398 | 870.12 |
| 134144 | 1.284 | 17.659 | 1275.78 |
| 144384 | 1.531 | 17.966 | 1073.66 |
| 154624 | 1.636 | 19.041 | 1063.99 |
| 164864 | 1.578 | 18.224 | 1054.63 |
| 175104 | 1.852 | 19.307 | 942.62 |
| 185344 | 1.951 | 19.420 | 895.32 |
| 195584 | 1.894 | 19.194 | 914.01 |
| 205824 | 1.953 | 19.811 | 914.43 |
| 216064 | 2.071 | 19.505 | 841.59 |
| 226304 | 2.163 | 19.914 | 820.58 |
| 236544 | 2.234 | 18.855 | 743.71 |
| 246784 | 2.557 | 21.409 | 737.61 |
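
A minimal sketch of the wrapper dispatch described above, assembled from the snippets in this PR; the `rmsnorm_wrapper` name and the way the two block-size lines combine are assumptions, not the exact merged code:

```python
import torch
import triton

def rmsnorm_wrapper(x: torch.Tensor, g: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    n_rows, n_cols = x.shape
    y = torch.zeros_like(x)
    # Cap the fused block at the 64 KiB spill threshold discussed in the
    # review below; otherwise cover the whole row with one power-of-two block.
    MAX_FUSED_SIZE = 65536 // x.element_size()  # elements that fit in 64 KiB
    blk_size = min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
    # The kernel switches to the blocked path when one block cannot cover
    # a full row (see the USE_BLOCKED discussion below).
    triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon)
    return y
```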

```python
def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
    BLOCK_SIZE = blk_size
    # Use blocked approach if BLOCK_SIZE > 65536
    USE_BLOCKED = BLOCK_SIZE > 31743
```
rahulbatra85 commented Dec 4, 2024

One thing I noticed is that the layernorm tutorial also uses the dtype to decide between the blocked and non-blocked paths. I think what matters is not the number of elements but their total size in bytes.

xiaohuguo2023 (Member, Author):

You are right; line 131 should be `if n_cols > 65535`.
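
In other words, a sketch of the dtype-aware check rahulbatra85 describes (the `MAX_FUSED_SIZE` expression appears in the benchmark diff later in this PR; its placement here is assumed):

```python
# Compare the row's byte footprint, not its element count: a 64 KiB
# budget holds twice as many fp16 elements as fp32 elements.
MAX_FUSED_SIZE = 65536 // x.element_size()  # e.g. 32768 for fp16, 16384 for fp32
USE_BLOCKED = n_cols > MAX_FUSED_SIZE
```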

xiaohuguo2023 (Member, Author): Done.

xiaohuguo2023 (Member, Author): Revised as suggested.

vgokhale (Collaborator) commented Dec 4, 2024

Can you please add a performance comparison with and without this change?

xiaohuguo2023 (Member, Author) commented

> Can you please add a performance comparison with and without this change?

See above.

```python
    mask = col_offsets < n_cols
    tl.assume(input_row_stride >= 0)
    tl.assume(output_row_stride >= 0)
    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
```
Collaborator:

Is this loop needed? NUM_PRGMS = n_rows in the caller, so it will never execute more than once?

xiaohuguo2023 (Member, Author): Yes; if N is small and M > 304, we need this persistent loop.

Collaborator:

In a persistent kernel the grid is sized according to the number of CUs. On lines 134 and 135, the grid is set to the number of rows, agnostic to the CU count. How is this kernel persistent?

xiaohuguo2023 (Member, Author): You are correct; this is a fake persistent loop. I will remove it, and making the kernel genuinely persistent will need a new PR.
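
For reference, a genuinely persistent launch would size the grid roughly like this (a sketch slotted into the wrapper where `n_rows` is defined; the kernel launch line is hypothetical, and 304 is the MI300X CU count matching the M > 304 remark above):

```python
import torch

# Size the grid by the CU count rather than the row count, so each
# program strides over rows and the tl.range loop actually iterates
# whenever n_rows exceeds the grid size.
NUM_CUS = torch.cuda.get_device_properties(0).multi_processor_count  # 304 on MI300X
NUM_PRGMS = min(n_rows, NUM_CUS)
grid = (NUM_PRGMS, )
# rmsnorm_kernel[grid](y, x, g, n_rows, n_cols, blk_size, NUM_PRGMS=NUM_PRGMS)
```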

@@ -153,7 +217,8 @@ def benchmark(M, N, provider)

```python
    x = torch.randn(M, N, device='cuda', dtype=dtype)
    y = torch.zeros_like(x, device='cuda')
    n_rows, n_cols = x.shape
    blk_size = triton.next_power_of_2(n_cols)
    MAX_FUSED_SIZE = 65536 // x.element_size()
```
Collaborator:

Can you add some comments on this magic number?

xiaohuguo2023 (Member, Author): Sorry, just saw your message. Anything greater than MAX_FUSED_SIZE starts to cause register spills.
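
Based on that explanation, the requested comment might read something like the following (wording assumed, not taken from the final diff):

```python
# 65536 bytes (64 KiB) per row is an empirical limit: loading more than
# this into one program exceeds the register budget and causes spills,
# so the fused (non-blocked) path is capped at MAX_FUSED_SIZE elements.
MAX_FUSED_SIZE = 65536 // x.element_size()
```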

vgokhale self-requested a review on December 6, 2024 at 22:16
vgokhale (Collaborator) commented Dec 6, 2024

LGTM after outstanding comments are addressed.

```python
    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
    g_ptrs = g_ptr + cols
    output_ptrs = row_output_ptr + cols
    x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg")
```
Collaborator:

Have you tried peeling the last iteration? Is it worth trying? Can you add a TODO to try that as part of your next PR?

xiaohuguo2023 (Member, Author): Loop peeling applies only to the blocked version, not this one. Sure, I will try that.
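
For the blocked path, peeling the last iteration would look roughly like this: full blocks skip the bounds mask entirely, and only the peeled tail pays the masking cost. This is a sketch with hypothetical names (`row_input_ptr`, `acc`), not the PR's actual kernel:

```python
# Inside a @triton.jit kernel: accumulate the sum of squares for one row.
num_full = n_cols // BLOCK_SIZE
acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
for b in range(num_full):
    cols = b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(row_input_ptr + cols, cache_modifier=".cg")  # full block: no mask
    acc += x * x
# Peeled last iteration: the only one that needs a bounds mask.
cols = num_full * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tail_mask = cols < n_cols
x = tl.load(row_input_ptr + cols, mask=tail_mask, other=0.0, cache_modifier=".cg")
acc += x * x
```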

xiaohuguo2023 merged commit 736071f into main_perf on Dec 6, 2024 — 4 checks passed.
xiaohuguo2023 deleted the rmsnorm_v2 branch on December 6, 2024 at 22:21.