Skip to content

Commit

Permalink
Add tests for pipelined descriptor creation
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Dec 2, 2024
1 parent 7854d9b commit 8243c50
Show file tree
Hide file tree
Showing 6 changed files with 592 additions and 2 deletions.
104 changes: 104 additions & 0 deletions python/test/unit/hopper/test_experimental_tma.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,107 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
)
torch.testing.assert_close(ref_out, A)
assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"]


@triton.jit
def batched_gemm_kernel(a_ptr, b_ptr, c_ptr, #
B, M, N, K, #
dtype: tl.constexpr, #
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, #
NUM_SMS: tl.constexpr):
start_pid = tl.program_id(axis=0)
num_tiles_m = tl.cdiv(M, BLOCK_M)
num_tiles_n = tl.cdiv(N, BLOCK_N)
k_tiles = tl.cdiv(K, BLOCK_K)
num_tiles_per_batch = num_tiles_m * num_tiles_n
num_tiles = B * num_tiles_per_batch

tiles_per_SM = num_tiles // NUM_SMS
if start_pid < num_tiles % NUM_SMS:
tiles_per_SM += 1

tile_id = start_pid - NUM_SMS
ki = -1

tile_m = 0
tile_n = 0
tile_b = 0

offs_m = 0
offs_n = 0
offs_b = 0

a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K])
b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K])
c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N])

accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

for _ in range(k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
tile_id += NUM_SMS
tile_b = tile_id // num_tiles_per_batch
tile_m = (tile_id // num_tiles_n) % num_tiles_m
tile_n = tile_id % num_tiles_n

offs_b = tile_b
offs_m = tile_m * BLOCK_M
offs_n = tile_n * BLOCK_N

a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1],
[BLOCK_M, BLOCK_K])
b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1],
[BLOCK_N, BLOCK_K])
c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1],
[BLOCK_M, BLOCK_N])

offs_k = ki * BLOCK_K

a = a_desc.load([offs_m, offs_k])
b = b_desc.load([offs_n, offs_k])
accumulator = tl.dot(a, b.T, accumulator)

if ki == k_tiles - 1:
c = accumulator.to(dtype)

c_desc.store([offs_m, offs_n], c)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)


@requires_tma
def test_tensor_descriptor_batched_gemm():
device = "cuda"
B, M, N, K = 2, 1024, 1024, 128
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
NUM_SMS = 96
num_stages = 3

grid = (min(NUM_SMS, B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)), )

a = torch.randn((B, M, K), device=device, dtype=torch.float16)
b = torch.randn((B, N, K), device=device, dtype=torch.float16)
c = torch.empty((B, M, N), device=device, dtype=torch.float16)

expect = torch.bmm(a, b.mT)

def alloc_fn(size: int, align: int, stream: Optional[int]):
# TODO: should only need num_stages * 3 descriptors per SM
assert size == 128 * 3 * (num_stages + 1) * grid[0]
assert align == 128
assert stream == 0
return torch.empty(size, dtype=torch.int8, device="cuda")

triton.set_allocator(alloc_fn)

h = batched_gemm_kernel[grid](
a, b, c, #
B, M, N, K, #
tl.float16, #
BLOCK_M, BLOCK_N, BLOCK_K, #
NUM_SMS, #
num_stages=num_stages, num_warps=8)
print(h.n_regs)
torch.cuda.synchronize()

torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3)
5 changes: 5 additions & 0 deletions test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}>
// CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in \
// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s

// CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32
Expand Down
3 changes: 2 additions & 1 deletion test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
// To regenerate this test case, run the command
// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \
// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="module" \
// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in \
// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir
// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c8_i32 = arith.constant 8 : i32
Expand Down
Loading

0 comments on commit 8243c50

Please sign in to comment.