Streamk v0.3 #660

Merged Dec 4, 2024 · 14 commits
@@ -3,6 +3,8 @@
 import random

 from streamk_kernel import streamk_gemm
+#from streamk_kernel_atomic import streamk_gemm
+#from persistent_gemm import streamk_gemm

 torch.manual_seed(123)
 random.seed(123)
@@ -95,6 +97,7 @@ def _call(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.Tensor,
         BLOCK_SIZE_K=BLK_K,
         GROUP_SIZE_M=gsize_m,
         NUM_SMS=total_programs_streamk,
+        STREAMK_TILES=total_tiles_streamk,
         NUM_XCDS=num_xcds,
         BIAS=use_bias,
         EVEN_K=even_k,
@@ -135,7 +138,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.

 ## test for tiles that are not a multiple of 304 tiles
 #m, n, k = 4096, 4096, 8192  # some problem size to test
-#m, n, k = 8192, 8192, 8192  # some problem size to test
+m, n, k = 8192, 8192, 8192  # some problem size to test
 #m, n, k = 512, 512, 512  # some problem size to test

 ## memory bound sizes
@@ -148,7 +151,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch.
 #m, n, k = 5632, 6656, 7936

 ## test when k is not a multiple of 16
-m, n, k = 4864, 4096, 4300
+#m, n, k = 4864, 4096, 4300

 A = torch.randn(m, k, device="cuda", dtype=torch.float16)
 B = torch.randn(n, k, device="cuda", dtype=torch.float16).T
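
`total_tiles_streamk` is computed on the host, outside the hunks shown here. A hypothetical sketch of that computation, mirroring the `total_tiles % NUM_SMS` rule the kernel itself used before this change (names follow the harness above; treat this as an assumption, not the PR's actual code):

```python
# hypothetical host-side computation of total_tiles_streamk (the real code is
# outside the hunks shown above)
total_blocks_M = triton.cdiv(m, BLK_M)
total_blocks_N = triton.cdiv(n, BLK_N)
total_tiles = total_blocks_M * total_blocks_N
total_tiles_streamk = total_tiles % total_programs_streamk  # leftover tiles run stream-k
```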
40 changes: 40 additions & 0 deletions python/perf-kernels/streamk/README.md
@@ -1,3 +1,43 @@
# streamk gemm script v0.3

### features added:

- new persistent gemm kernel
- gemm benchmark tool using a nearest-neighbour approach to pick the closest tuned config for each problem size (see the sketch after the command below)

### benchmark commandline

```
python gemm_benchmark.py
```
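
The selection utilities (`utils/solution_selection.py`) are not part of this diff. As a rough sketch of the idea, the benchmark queries a k-d tree of tuned gemm sizes and reuses the config of the nearest tuned size. The names below are illustrative, and scipy's `cKDTree` stands in for whatever tree the utility actually builds:

```python
# illustrative sketch of the nearest-neighbour config lookup
# (hypothetical names; the real structures live in utils/solution_selection.py)
import numpy as np
from scipy.spatial import cKDTree

# sizes tuned offline, and the winning config for each
tuned_sizes = np.array([[1024, 1024, 1024], [4096, 4096, 8192], [8192, 8192, 8192]])
tuned_configs = [
    {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64},
    {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64},
    {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64},
]
tree = cKDTree(tuned_sizes)


def nearest_config(m, n, k):
    # query returns (distance, index) of the closest tuned size
    _, idx = tree.query(np.array([m, n, k]).reshape(1, -1))
    return tuned_configs[idx[0]]


print(nearest_config(8000, 8000, 8000))  # picks the 8192x8192x8192 config
```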

# streamk gemm script v0.2

### features added:

- new streamk tuning script to reduce compiling and profiling time

- use load/store cache modifiers to reimplement the spin lock (a minimal sketch follows this list)

- add CI test for streamk-kernel
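
The spin lock shows up in context in `streamk_kernel.py` below; reduced to one producer/consumer pair it looks roughly like this standalone illustration (not the PR's exact code):

```python
# standalone sketch of the load/store cache-modifier spin lock
import triton
import triton.language as tl


@triton.jit
def spin_lock_demo(P, locks, out, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    if pid == 0:
        # producer: write the partial results through to memory (".wt"),
        # then publish the flag the same way
        tl.store(P + offs, tl.full((BLOCK, ), 1.0, tl.float32), cache_modifier=".wt")
        tl.store(locks + pid, 1, cache_modifier=".wt")
    else:
        # consumer: poll the flag with a volatile ".cv" load so each probe
        # bypasses stale cache lines, then consume the partials
        while tl.load(locks, cache_modifier=".cv", volatile=True) != 1:
            pass
        acc = tl.load(P + offs)
        tl.store(out + offs, acc)
```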

### potential issues:

- there may be a hanging issue when using random grid sizes
- large register spills when using tile size 256x256x64

### tuning command

```
python tune_streamk.py --gemm_size_file input_nn_size.yaml --ngpus 8 --jobs 24
```

### calculate occupancy

```
../tools/occ.sh "python tune_streamk.py --gemm_size_file single_item.yaml --compare_wo_tuning"
```

# streamk gemm script v0.1

The plan is to use this version as the base for future triton streamk gemm development.
103 changes: 103 additions & 0 deletions python/perf-kernels/streamk/gemm_benchmark.py
@@ -0,0 +1,103 @@
import os
import json
import torch
import triton
import numpy as np

from utils.solution_selection import tunedtree, tunedarr, solution_params
from utils.gemm_wrapper import matmul


def selection_test():
    sm = 304

    # max_m, max_n, max_k = 8193, 8193, 8193
    # A_max = torch.randn(max_m, max_k, device="cuda", dtype=torch.float16)
    # B_max = torch.randn(max_n, max_k, device="cuda", dtype=torch.float16)
    # C_max = torch.zeros(max_m, max_n, device="cuda", dtype=torch.float16)
    # locks_max = torch.zeros((sm,), device="cuda", dtype=torch.int32)
    # P_max = torch.zeros((sm, 256*256), device="cuda", dtype=torch.float32)
    # bias_max = torch.zeros((max_m,), device="cuda", dtype=torch.float16)

    # Temporary tensors with the maximum size
    # A = torch.empty(max_m, max_k, device="cuda", dtype=torch.float16)
    # B = torch.empty(max_n, max_k, device="cuda", dtype=torch.float16)
    # output = torch.empty(max_m, max_n, device="cuda", dtype=torch.float16)
    # locks = torch.empty((sm,), device="cuda", dtype=torch.int32)
    # P = torch.empty((sm, 256*256), device="cuda", dtype=torch.float32)
    # bias = torch.empty((max_m,), device="cuda", dtype=torch.float16)

    # Remove existing benchmark file if it exists
    if os.path.exists('benchmark.json'):
        os.remove('benchmark.json')

    with open("benchmark.json", "a") as f:
        for m in range(128, 8193, 250):
            for n in range(128, 8193, 250):
                for k in range(128, 8193, 250):
                    print(m, n, k)
                    # Point A and B to the appropriate slices of A_max and B_max
                    # A.set_(A_max.storage(), 0, (m, k))
                    # B.set_(B_max.storage(), 0, (n, k)).T
                    A = torch.randn(m, k, device="cuda", dtype=torch.float16)
                    B = torch.randn(n, k, device="cuda", dtype=torch.float16).T

                    expected = A @ B
                    pytorch_ms = triton.testing.do_bench(lambda: A @ B)

                    # nearest-neighbour lookup of the closest tuned size
                    dist, treeidx = tunedtree.query(np.array([m, n, k]).reshape(1, -1))
                    print(f"{dist}")
                    mt = solution_params[tunedarr[treeidx[0][0]]]
                    print(f"{mt}")
                    BLK_M = mt['BLOCK_SIZE_M']
                    BLK_N = mt['BLOCK_SIZE_N']
                    BLK_K = mt['BLOCK_SIZE_K']
                    gsize_m = mt['GROUP_SIZE_M']
                    two_tiles = 'True'
                    num_stages = mt['num_stages']
                    num_warps = mt['num_warps']
                    waves_per_eu = mt['waves_per_eu']
                    mfmaInstrSize = mt['matrix_instr_nonkdim']
                    kpack = mt['kpack']
                    total_blocks_M = triton.cdiv(m, BLK_M)
                    total_blocks_N = triton.cdiv(n, BLK_N)
                    total_tiles = total_blocks_M * total_blocks_N

                    # output.set_(C_max.storage(), 0, (m, n))
                    # locks.set_(locks_max.storage(), 0, (sm,))
                    # P.set_(P_max.storage(), 0, (sm, BLK_M*BLK_N))
                    # bias.set_(bias_max.storage(), 0, (m,))
                    output = torch.zeros(m, n, device="cuda", dtype=torch.float16)
                    locks = torch.zeros((sm, ), device="cuda", dtype=torch.int32)
                    P = torch.zeros((sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32)
                    bias = torch.zeros((m, ), device="cuda", dtype=torch.float16)
                    triton_ms = triton.testing.do_bench(
                        lambda: matmul.apply(A, B, output, bias, P, locks, sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles,
                                             num_stages, num_warps, waves_per_eu, mfmaInstrSize, kpack))

                    # max element-wise discrepancy between the triton and torch results;
                    # large tolerance to accommodate for large K (rounding due to half precision)
                    max_disc = (output - expected).abs().max().item()
                    assert max_disc <= 5., (
                        f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}"
                    )
                    info = {
                        "m": m,
                        "n": n,
                        "k": k,
                        "MT0": BLK_M,
                        "MT1": BLK_N,
                        "DepU": BLK_K,
                        "sm": sm,
                        "GROUP_SIZE_M": gsize_m,
                        "total_tiles": total_tiles,
                        "num_warps": num_warps,
                        "mfmaInstrSize": mfmaInstrSize,
                        "kpack": kpack,
                        "disc": max_disc,
                        "triton_ms": triton_ms,
                        "pytorch_ms": pytorch_ms,
                    }
                    json.dump(info, f)
                    f.write('\n')


selection_test()
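
Each iteration appends one JSON object per line, so `benchmark.json` reads back as JSON Lines. A minimal post-processing sketch (not part of the PR), using only keys the script writes:

```python
import json

# report the stream-k speedup over torch.matmul for every benchmarked shape
with open("benchmark.json") as f:
    for line in f:
        r = json.loads(line)
        speedup = r["pytorch_ms"] / r["triton_ms"]
        print(f'{r["m"]}x{r["n"]}x{r["k"]}: {speedup:.2f}x vs torch')
```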
96 changes: 96 additions & 0 deletions python/perf-kernels/streamk/persistent_gemm.py
@@ -0,0 +1,96 @@
import triton
import triton.language as tl


@triton.jit()
def streamk_gemm(
        A,
        B,
        C,
        bias_ptr,
        P,
        locks,
        M,
        N,
        K,
        stride_am,
        stride_ak,
        stride_bk,
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bias,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        NUM_SMS: tl.constexpr,
        STREAMK_TILES: tl.constexpr,
        NUM_XCDS: tl.constexpr,
        BIAS: tl.constexpr,
        EVEN_K: tl.constexpr,
):
    pid = tl.program_id(0)
    if NUM_XCDS != 1:
        # remap program ids so consecutive pids spread round-robin across XCDs
        pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    total_tiles = num_pid_m * num_pid_n

    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
    # persistent loop: each program walks the tiles with a stride of NUM_SMS
    for tile_id in range(pid, total_tiles, NUM_SMS):
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rk = tl.arange(0, BLOCK_SIZE_K)
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
        B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
        tl.assume(pid_m > 0)
        tl.assume(pid_n > 0)

        # with EVEN_K the whole K extent is covered by full blocks;
        # otherwise peel off the last iteration and mask it below
        loop_k = tl.cdiv(K, BLOCK_SIZE_K)
        if not EVEN_K:
            loop_k -= 1

        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
        for k in range(0, loop_k):
            a = tl.load(tl.multiple_of(A_BASE, (1, 16)))
            b = tl.load(tl.multiple_of(B_BASE, (16, 1)))
            acc += tl.dot(a, b)
            A_BASE += BLOCK_SIZE_K * stride_ak
            B_BASE += BLOCK_SIZE_K * stride_bk

        if not EVEN_K:
            k = loop_k
            rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            A_BASE = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
            B_BASE = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
            A_BASE = tl.multiple_of(A_BASE, (1, 16))
            B_BASE = tl.multiple_of(B_BASE, (16, 1))
            a = tl.load(A_BASE, mask=rk[None, :] < K, other=0.0)
            b = tl.load(B_BASE, mask=rk[:, None] < K, other=0.0)
            acc += tl.dot(a, b)

        c = acc.to(C.type.element_ty)
        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        c_mask = (rm[:, None] < M) & (rn[None, :] < N)
        C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
        tl.store(C_, c, c_mask)
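
The grouped tile ordering above (`GROUP_SIZE_M`) can be mirrored on the host for debugging. A standalone sketch of the same index arithmetic (plain Python, nothing beyond what the kernel computes):

```python
# host-side mirror of the kernel's grouped tile mapping: tiles are visited in
# groups of GROUP_SIZE_M rows so programs on nearby tiles reuse rows of A
def tile_to_pid(tile_id, num_pid_m, num_pid_n, group_size_m):
    num_pid_in_group = group_size_m * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m
    size_m = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + ((tile_id % num_pid_in_group) % size_m)
    pid_n = (tile_id % num_pid_in_group) // size_m
    return pid_m, pid_n


# 8x8 tiles with GROUP_SIZE_M=4: the first 8 tiles walk down a 4-row column
print([tile_to_pid(t, 8, 8, 4) for t in range(8)])
# -> [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
```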
41 changes: 25 additions & 16 deletions python/perf-kernels/streamk/streamk_kernel.py
@@ -25,31 +25,26 @@ def streamk_gemm(
         BLOCK_SIZE_K: tl.constexpr,
         GROUP_SIZE_M: tl.constexpr,
         NUM_SMS: tl.constexpr,
+        STREAMK_TILES: tl.constexpr,
         NUM_XCDS: tl.constexpr,
         BIAS: tl.constexpr,
         EVEN_K: tl.constexpr,
 ):
     pid = tl.program_id(0)
     if NUM_XCDS != 1:
         pid = (pid % NUM_XCDS) * (NUM_SMS // NUM_XCDS) + (pid // NUM_XCDS)

     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
     iters_per_tile = tl.cdiv(K, BLOCK_SIZE_K)
     total_tiles = num_pid_m * num_pid_n
-    if NUM_SMS > 0 and total_tiles > NUM_SMS:
-        total_streamk_tiles = total_tiles % NUM_SMS
-        # total_streamk_tiles = total_streamk_tiles + NUM_SMS
-        total_full_tiles = total_tiles - total_streamk_tiles
-        total_streamk_iters = total_streamk_tiles * iters_per_tile
-        streamk_iters_pcu = total_streamk_iters // NUM_SMS
-        streamk_remainder_iters = total_streamk_iters % NUM_SMS
-    else:
-        total_full_tiles = total_tiles
-        total_streamk_tiles = 0
-        streamk_iters_pcu = 0
-        streamk_remainder_iters = 0
-        total_streamk_iters = 0
+    total_full_tiles = total_tiles - STREAMK_TILES
+
+    tl.assume(stride_am > 0)
+    tl.assume(stride_ak > 0)
+    tl.assume(stride_bn > 0)
+    tl.assume(stride_bk > 0)
+    tl.assume(stride_cm > 0)
+    tl.assume(stride_cn > 0)

     acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

@@ -60,6 +55,8 @@ def streamk_gemm(
         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
         pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
         pid_n = (tile_id % num_pid_in_group) // group_size_m
+        tl.assume(pid_m > 0)
+        tl.assume(pid_n > 0)

         rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
         rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -100,12 +97,18 @@ def streamk_gemm(
         if BIAS:
             c += bias[:, None]

-        rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-        rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+        rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+        rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+        rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M)
+        rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
         C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
         mask = (rm < M)[:, None] & (rn < N)[None, :]
         tl.store(C_, c, mask=mask)

+    tl.assume(pid >= 0)
+    total_streamk_iters = STREAMK_TILES * iters_per_tile
+    streamk_iters_pcu = total_streamk_iters // NUM_SMS
+    streamk_remainder_iters = total_streamk_iters % NUM_SMS
     start_iter = total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters)
     last_iter = total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum(
         pid + 1, streamk_remainder_iters)
@@ -119,6 +122,8 @@ def streamk_gemm(
         group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
         pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
         pid_n = (tile_id % num_pid_in_group) // group_size_m
+        tl.assume(pid_m > 0)
+        tl.assume(pid_n > 0)

         rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
         rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -158,11 +163,15 @@ def streamk_gemm(
             P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
             tl.store(P_, acc, cache_modifier=".wt")
             tl.store(locks + pid, 1, cache_modifier=".wt")
+            # tl.store(P_, acc)
+            # tl.debug_barrier()
+            # tl.atomic_xchg(locks + pid, 1)
         else:
             next_pid = pid + 1
             tile_iter_end = tile_iter + iters_per_tile
             end = end_iter
             while (end < tile_iter_end and next_pid < NUM_SMS):
+                # while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
                 while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
                     pass
                 rm1 = tl.arange(0, BLOCK_SIZE_M)
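
The iteration split that `STREAMK_TILES` drives can be checked with plain integer arithmetic on the host. A standalone sketch with illustrative numbers (not the PR's code):

```python
# the leftover tiles' K-iterations are split as evenly as possible across
# NUM_SMS programs; the first `remainder` programs take one extra iteration
def streamk_split(total_tiles, iters_per_tile, num_sms, streamk_tiles):
    total_full_tiles = total_tiles - streamk_tiles
    total_streamk_iters = streamk_tiles * iters_per_tile
    iters_pcu = total_streamk_iters // num_sms
    remainder = total_streamk_iters % num_sms
    base = total_full_tiles * iters_per_tile
    return [(base + pid * iters_pcu + min(pid, remainder),
             base + (pid + 1) * iters_pcu + min(pid + 1, remainder))
            for pid in range(num_sms)]


# 310 tiles on 304 CUs, 6 leftover stream-k tiles of 128 K-iterations each:
# 768 iterations -> the first 160 programs get 3, the rest get 2
print(streamk_split(310, 128, 304, 6)[:2])  # [(38912, 38915), (38915, 38918)]
```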