rmsnorm optimization for M = 1 #668

Merged 4 commits on Dec 2, 2024
Changes from 3 commits
58 changes: 34 additions & 24 deletions python/perf-kernels/rmsnorm.py
@@ -25,15 +25,18 @@ def get_cuda_autotune_config():

def get_hip_autotune_config():
    return [
-        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=1),
-        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=1),
+        triton.Config({'waves_per_eu': 0}, num_warps=4, num_stages=2),

Review comment:
I think this can be simplified, if you want:
[triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2) for (we, nw) in itertools.product([0, 1, 2, 4], [4, 8, 16])]


Member Author (xiaohuguo2023):
done

+        triton.Config({'waves_per_eu': 0}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 0}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 1}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 2}, num_warps=16, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=4, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=8, num_stages=2),
+        triton.Config({'waves_per_eu': 4}, num_warps=16, num_stages=2),
    ]
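
A minimal sketch of the list-comprehension form suggested in the review above — presumably what the fourth commit applies, since this view shows changes from only three commits. The itertools import is assumed:

```python
import itertools

import triton


def get_hip_autotune_config():
    # Same parameter grid as the hand-written list above:
    # waves_per_eu in {0, 1, 2, 4} crossed with num_warps in {4, 8, 16}.
    return [
        triton.Config({'waves_per_eu': we}, num_warps=nw, num_stages=2)
        for we, nw in itertools.product([0, 1, 2, 4], [4, 8, 16])
    ]
```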


@@ -47,12 +50,13 @@ def get_autotune_config():
@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows, n_cols, epsilon,
-               BLOCK_SIZE: tl.constexpr):
+               BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
    row_start = tl.program_id(0)
-    row_step = tl.num_programs(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
-    for row_idx in tl.range(row_start, n_rows, row_step):
+    tl.assume(input_row_stride >= 0)
+    tl.assume(output_row_stride >= 0)
+    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):

Review comment:
Can we move the line g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0) before the loop, so we do not have to load it on every iteration? Or can the compiler do that for us?


Member Author (xiaohuguo2023), Dec 2, 2024:
Not sure why, but keeping it in the loop gives slightly better performance on average:

|      N   | Triton (Old) | Triton (New) | Improvement (%) |
|----------|--------------|--------------|-----------------|
|   8192.0 |      2.822   |     3.035    |      7.55       |
|   9216.0 |      3.776   |     4.326    |     14.57       |
|  10240.0 |      4.165   |     4.734    |     13.67       |
|  11264.0 |      4.599   |     5.690    |     23.71       |
|  12288.0 |      5.235   |     5.265    |      0.57       |
|  13312.0 |      5.541   |     5.952    |      7.41       |
|  14336.0 |      6.304   |     5.941    |     -5.77       |
|  15360.0 |      7.544   |     7.380    |     -2.18       |
|  16384.0 |      7.069   |     7.664    |      8.43       |
|  17408.0 |      7.652   |     8.269    |      8.07       |
|  18432.0 |      8.110   |     8.330    |      2.71       |
|  19456.0 |      8.712   |     9.441    |      8.37       |
|  20480.0 |      8.915   |     9.488    |      6.43       |
|  21504.0 |     10.047   |    10.324    |      2.76       |
|  22528.0 |      9.858   |    10.207    |      3.54       |
|  23552.0 |     10.062   |     9.712    |     -3.48       |
|  24576.0 |     11.465   |    10.408    |     -9.23       |
|  25600.0 |     10.968   |    10.732    |     -2.15       |
|  26624.0 |     12.666   |    10.422    |    -17.72       |
|  27648.0 |     11.786   |    12.288    |      4.26       |
|  28672.0 |     13.457   |    11.591    |    -13.89       |
|  29696.0 |     12.321   |    13.150    |      6.73       |
|  30720.0 |     12.698   |    13.575    |      6.91       |
|  31744.0 |     14.271   |    15.549    |      8.95       |


Review comment:
Oh, really? If that is the case, we can keep it in the loop.

        row_start_ptr = input_ptr + row_idx * input_row_stride
        input_ptrs = row_start_ptr + col_offsets
        input_ptrs = tl.multiple_of(input_ptrs, (16, ))
@@ -72,15 +76,12 @@ def rms_kernel(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride
        tl.store(output_ptrs, rms_norm, mask=mask)
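
To make the two options in the review thread above concrete, here is a minimal standalone sketch of the hoisted variant — an illustration only, not the PR's exact kernel (whose middle section is collapsed in this view); the kernel name rms_kernel_hoisted is hypothetical. The merged version keeps the g load inside the loop, which measured slightly faster on average:

```python
import triton
import triton.language as tl


@triton.jit
def rms_kernel_hoisted(output_ptr, input_ptr, g_ptr, input_row_stride, output_row_stride, n_rows,
                       n_cols, epsilon, BLOCK_SIZE: tl.constexpr, NUM_PRGMS: tl.constexpr):
    row_start = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Hoisted variant: load g once, before the persistent row loop.
    g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0)
    for row_idx in tl.range(row_start, n_rows, NUM_PRGMS):
        row = tl.load(input_ptr + row_idx * input_row_stride + col_offsets, mask=mask, other=0.0)
        # RMSNorm: x / sqrt(mean(x^2) + eps), then scale by g.
        mean_sq = tl.sum(row * row, axis=0) / n_cols
        rms_norm = row / tl.sqrt(mean_sq + epsilon) * g
        tl.store(output_ptr + row_idx * output_row_stride + col_offsets, rms_norm, mask=mask)
```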


-def triton_rmsnorm(x, g, epsilon=1e-6):
-    n_rows, n_cols = x.shape
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
-
-    y = torch.empty_like(x, device='cuda')
+def triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size, epsilon=1e-6):
+    BLOCK_SIZE = blk_size

-    num_programs = n_rows
-    grid = lambda meta: (num_programs, )
-    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE)
+    NUM_PRGMS = n_rows
+    grid = lambda meta: (NUM_PRGMS, )
+    rms_kernel[grid](y, x, g, x.stride(0), y.stride(0), n_rows, n_cols, epsilon, BLOCK_SIZE, NUM_PRGMS)

    return y

@@ -107,8 +108,11 @@ def torch_rmsnorm(x, g):
def test_rmsnorm(M, N):
    torch.manual_seed(0)
    x = torch.randn(M, N, device='cuda')
+    y = torch.zeros_like(x, device='cuda')
+    n_rows, n_cols = x.shape
+    blk_size = triton.next_power_of_2(n_cols)
    g = torch.ones((1, N), device='cuda')
-    y_triton = triton_rmsnorm(x, g)
+    y_triton = triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)

    y_torch = torch_rmsnorm(x, g)
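
The torch_rmsnorm reference is collapsed in this view; for orientation, a standard PyTorch RMSNorm looks roughly like the sketch below. The function name torch_rmsnorm_reference is hypothetical, and the file's actual epsilon handling and dtype casts may differ:

```python
import torch


def torch_rmsnorm_reference(x: torch.Tensor, g: torch.Tensor, epsilon: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2, dim=-1) + eps) * g
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return x * inv_rms * g
```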

@@ -157,13 +161,16 @@ def run_benchmark(args):
    @triton.testing.perf_report(config)
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=dtype)
+        y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)
        stream = torch.cuda.Stream()
        torch.cuda.set_stream(stream)
        g = torch.ones((1, N), device='cuda')
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch_rmsnorm(x, g))
        if provider == 'triton':
-            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, g))
+            ms = triton.testing.do_bench(lambda: triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size))
        # Effective bandwidth in GB/s: each element is read once and written once (factor 2).
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)

@@ -194,9 +201,12 @@ def parse_args():
def main():
    args = parse_args()
    if args.no_benchmark:
-        x = torch.randn(args.M_start, args.N_start)
+        x = torch.randn(args.M_start, args.N_start, device='cuda')
+        y = torch.zeros_like(x, device='cuda')
+        n_rows, n_cols = x.shape
+        blk_size = triton.next_power_of_2(n_cols)

Review comment:
@xiaohuguo2023 Did you notice any big performance drop if blk_size > 65k?


Member Author (xiaohuguo2023):
Yes, with blk_size > 65k we start to see VGPR spills.


Member Author (xiaohuguo2023):
My next PR will address this issue.

        g = torch.ones((1, args.N_start), device='cuda')
-        triton_rmsnorm(x, g)
+        triton_rmsnorm(x, y, g, n_rows, n_cols, blk_size)
    else:
        run_benchmark(args)
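
On the blk_size > 65k spill discussion above: one common mitigation — purely an illustration, not necessarily what the follow-up PR does — is a two-pass kernel that walks each row in fixed-size chunks, so register use is bounded by CHUNK_SIZE rather than the full row width. The kernel name rms_kernel_chunked and the CHUNK_SIZE parameter are hypothetical:

```python
import triton
import triton.language as tl


@triton.jit
def rms_kernel_chunked(output_ptr, input_ptr, g_ptr, row_stride, n_cols, epsilon,
                       CHUNK_SIZE: tl.constexpr):
    # One program per row; input and output are assumed to share row_stride.
    row_idx = tl.program_id(0)
    row_ptr = input_ptr + row_idx * row_stride
    # Pass 1: accumulate the sum of squares across the row, one chunk at a time.
    sum_sq = 0.0
    for start in tl.range(0, n_cols, CHUNK_SIZE):
        offs = start + tl.arange(0, CHUNK_SIZE)
        mask = offs < n_cols
        x = tl.load(row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        sum_sq += tl.sum(x * x, axis=0)
    inv_rms = 1.0 / tl.sqrt(sum_sq / n_cols + epsilon)
    # Pass 2: re-read each chunk, normalize, scale by g, and store.
    for start in tl.range(0, n_cols, CHUNK_SIZE):
        offs = start + tl.arange(0, CHUNK_SIZE)
        mask = offs < n_cols
        x = tl.load(row_ptr + offs, mask=mask, other=0.0).to(tl.float32)
        g = tl.load(g_ptr + offs, mask=mask, other=0.0)
        y = (x * inv_rms * g).to(output_ptr.dtype.element_ty)
        tl.store(output_ptr + row_idx * row_stride + offs, y, mask=mask)
```

The trade-off is that each row is read twice, but live registers stay proportional to CHUNK_SIZE (e.g. 8192) no matter how wide the row is.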
