diff --git a/python/test/unit/hopper/test_experimental_tma.py b/python/test/unit/hopper/test_experimental_tma.py index 23065953d65b..38e2da747504 100644 --- a/python/test/unit/hopper/test_experimental_tma.py +++ b/python/test/unit/hopper/test_experimental_tma.py @@ -538,3 +538,108 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): ) torch.testing.assert_close(ref_out, A) assert "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned" in kernel.asm["ptx"] + + +@triton.jit +def batched_gemm_kernel(a_ptr, b_ptr, c_ptr, # + B, M, N, K, # + dtype: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SMS: tl.constexpr): + start_pid = tl.program_id(axis=0) + num_tiles_m = tl.cdiv(M, BLOCK_M) + num_tiles_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles_per_batch = num_tiles_m * num_tiles_n + num_tiles = B * num_tiles_per_batch + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + tile_m = 0 + tile_n = 0 + tile_b = 0 + + offs_m = 0 + offs_n = 0 + offs_b = 0 + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], [BLOCK_M, BLOCK_N]) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + for _ in range(k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + tile_b = tile_id // num_tiles_per_batch + tile_m = (tile_id // num_tiles_n) % num_tiles_m + tile_n = tile_id % num_tiles_n + + offs_b = tile_b + offs_m = tile_m * BLOCK_M + offs_n = tile_n * BLOCK_N + + a_desc = tl._experimental_make_tensor_descriptor(a_ptr + offs_b * (M * K), [M, K], [K, 1], + [BLOCK_M, BLOCK_K]) + b_desc = tl._experimental_make_tensor_descriptor(b_ptr + offs_b * (N * K), [N, K], [K, 1], + [BLOCK_N, BLOCK_K]) + c_desc = tl._experimental_make_tensor_descriptor(c_ptr + offs_b * (M * N), [M, N], [N, 1], + [BLOCK_M, BLOCK_N]) + + offs_k = ki * BLOCK_K + + a = a_desc.load([offs_m, offs_k]) + b = b_desc.load([offs_n, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + + if ki == k_tiles - 1: + c = accumulator.to(dtype) + + c_desc.store([offs_m, offs_n], c) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + +@requires_tma +def test_tensor_descriptor_batched_gemm(): + device = "cuda" + B, M, N, K = 2, 1024, 1024, 128 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + NUM_SMS = 96 + num_stages = 3 + + M, N, K = BLOCK_M, BLOCK_N, BLOCK_K + + a = torch.randn((B, M, K), device=device, dtype=torch.float16) + b = torch.randn((B, N, K), device=device, dtype=torch.float16) + c = torch.empty((B, M, N), device=device, dtype=torch.float16) + + expect = torch.bmm(a, b.mT) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + # TODO: should only need num_stages * 3 descriptors per SM + # assert size == 128 * 3 * (num_stages + 1) * NUM_SMS + assert align == 128 + assert stream == 0 + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + grid = lambda META: (min(NUM_SMS, B * triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])), ) + h = batched_gemm_kernel[grid]( + a, b, c, # + B, M, N, K, # + tl.float16, # + BLOCK_M, BLOCK_N, BLOCK_K, # + NUM_SMS, # + 
num_stages=num_stages, num_warps=8) + print(h.n_regs) + torch.cuda.synchronize() + + torch.testing.assert_close(c, expect, rtol=1e-3, atol=1e-3) diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir index 4bcff281b5c0..1f54674c35bc 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -10,7 +10,12 @@ // CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> // CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> // CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in \ +// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s + // CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors( // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { // CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32 diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in index 5175c1a6cefa..69e75b91fcd7 100644 --- a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in @@ -1,8 +1,9 @@ // To regenerate this test case, run the command // triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ -// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in --source_delim_regex="module" \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in \ // -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir // RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s + module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { %c8_i32 = arith.constant 8 : i32 diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir b/test/TritonGPU/samples/simulated-grouped-gemm.mlir new file mode 100644 index 000000000000..792136758365 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir @@ -0,0 +1,376 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + 
+// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. + +// CHECK: #[[$ATTR_0:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// CHECK-LABEL: tt.func public @matmul_kernel_descriptor_persistent( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 4 : i64 +// CHECK: %[[VAL_7:.*]] = arith.constant 2 : i64 +// CHECK: %[[VAL_8:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant false +// CHECK: %[[VAL_11:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 132 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_15:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_20:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_21:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_22:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_23:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_24:.*]] = tt.get_program_id x : i32 +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_3]], %[[VAL_20]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_25]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_4]], %[[VAL_21]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.divsi %[[VAL_27]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.addi %[[VAL_5]], %[[VAL_22]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.muli %[[VAL_26]], %[[VAL_28]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_33:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_34:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_32]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_35:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: 
%[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_35]], %[[VAL_19]]] : , > +// CHECK: %[[VAL_37:.*]] = arith.divsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_38:.*]] = arith.remsi %[[VAL_31]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_39:.*]] = arith.cmpi slt, %[[VAL_24]], %[[VAL_38]] : i32 +// CHECK: %[[VAL_40:.*]] = scf.if %[[VAL_39]] -> (i32) { +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_37]], %[[VAL_11]] : i32 +// CHECK: scf.yield %[[VAL_41]] : i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_37]] : i32 +// CHECK: } +// CHECK: %[[VAL_42:.*]] = arith.subi %[[VAL_24]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_43:.*]] = arith.muli %[[VAL_28]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_44:.*]] = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 +// CHECK: %[[VAL_45:.*]] = arith.muli %[[VAL_30]], %[[VAL_40]] : i32 +// CHECK: %[[VAL_46:.*]] = arith.subi %[[VAL_30]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_47:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_48:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_49:.*]] = ttg.global_scratch_alloc {alignment = 128 : i32, nbytes = 384 : i32} : !tt.ptr +// CHECK: %[[VAL_50:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_51:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_52:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_53]], 1 : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_54:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_54]], 1 : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_55:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_55]], 1 : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_56:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_56]], %[[VAL_24]], %[[VAL_42]] : i32 +// CHECK: %[[VAL_58:.*]] = arith.select %[[VAL_56]], %[[VAL_14]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_59:.*]]:2 = scf.if %[[VAL_56]] -> (i32, i32) { +// CHECK: %[[VAL_60:.*]] = arith.divsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_61:.*]] = arith.muli %[[VAL_60]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_62:.*]] = arith.subi %[[VAL_26]], %[[VAL_61]] : i32 +// CHECK: %[[VAL_63:.*]] = arith.minsi %[[VAL_62]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_64:.*]] = arith.remsi %[[VAL_24]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_65:.*]] = arith.addi %[[VAL_61]], %[[VAL_64]] : i32 +// CHECK: %[[VAL_66:.*]] = arith.remsi %[[VAL_24]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_67:.*]] = arith.divsi %[[VAL_66]], %[[VAL_63]] : i32 +// CHECK: %[[VAL_68:.*]] = arith.muli %[[VAL_65]], %[[VAL_16]] : i32 +// CHECK: 
%[[VAL_69:.*]] = arith.muli %[[VAL_67]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_68]], %[[VAL_69]] : i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_14]], %[[VAL_14]] : i32, i32 +// CHECK: } +// CHECK: %[[VAL_70:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_70]], 49152, %[[VAL_56]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_71:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_72:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_33]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_72]]{{\[}}%[[VAL_73:.*]]#0, %[[VAL_14]]] %[[VAL_71]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_74:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_14]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_75:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_34]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_75]]{{\[}}%[[VAL_73]]#1, %[[VAL_14]]] %[[VAL_74]], %[[VAL_70]], %[[VAL_56]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_76:.*]] = arith.cmpi sgt, %[[VAL_45]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_77:.*]] = arith.cmpi ne, %[[VAL_46]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_78:.*]] = arith.extui %[[VAL_77]] : i1 to i32 +// CHECK: %[[VAL_79:.*]] = arith.cmpi eq, %[[VAL_78]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_80:.*]] = arith.andi %[[VAL_76]], %[[VAL_79]] : i1 +// CHECK: %[[VAL_81:.*]]:10 = scf.if %[[VAL_80]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_58]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi eq, %[[VAL_82]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_14]], %[[VAL_82]] : i32 +// CHECK: %[[VAL_85:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_86:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_87:.*]] = arith.extui %[[VAL_83]] : i1 to i32 +// CHECK: %[[VAL_88:.*]]:3 = scf.if %[[VAL_83]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>) { +// CHECK: %[[VAL_89:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_90:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_91:.*]] = arith.shrsi %[[VAL_90]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_47]], %[[VAL_89]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_91]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_47]] : !tt.ptr +// CHECK: %[[VAL_92:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_47]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_93:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 
+// CHECK: %[[VAL_94:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_95:.*]] = arith.shrsi %[[VAL_94]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_48]], %[[VAL_93]], {{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_95]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_48]] : !tt.ptr +// CHECK: %[[VAL_96:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_48]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_97:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_98:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_99:.*]] = arith.shrsi %[[VAL_98]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_49]], %[[VAL_97]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_99]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_49]] : !tt.ptr +// CHECK: %[[VAL_100:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_49]] : !tt.ptr to !tt.tensordesc> +// CHECK: scf.yield %[[VAL_92]], %[[VAL_96]], %[[VAL_100]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc> +// CHECK: } +// CHECK: %[[VAL_101:.*]] = arith.addi %[[VAL_57]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_102:.*]] = arith.divsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_103:.*]] = arith.muli %[[VAL_102]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_104:.*]] = arith.subi %[[VAL_26]], %[[VAL_103]] : i32 +// CHECK: %[[VAL_105:.*]] = arith.minsi %[[VAL_104]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_106:.*]] = arith.remsi %[[VAL_101]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_107:.*]] = arith.addi %[[VAL_103]], %[[VAL_106]] : i32 +// CHECK: %[[VAL_108:.*]] = arith.remsi %[[VAL_101]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_109:.*]] = arith.divsi %[[VAL_108]], %[[VAL_105]] : i32 +// CHECK: %[[VAL_110:.*]] = arith.muli %[[VAL_107]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_111:.*]] = arith.muli %[[VAL_109]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_112:.*]]#0, %[[VAL_112]]#1, %[[VAL_112]]#2, %[[VAL_101]], %[[VAL_84]], %[[VAL_110]], %[[VAL_111]], %[[VAL_85]], %[[VAL_86]], %[[VAL_87]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_33]], %[[VAL_34]], %[[VAL_36]], %[[VAL_57]], %[[VAL_58]], %[[VAL_73]]#0, %[[VAL_73]]#1, %[[VAL_14]], %[[VAL_14]], %[[VAL_14]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_113:.*]] = arith.muli %[[VAL_78]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_114:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_114]], 49152, %[[VAL_76]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_115:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, 
mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_116:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_116]]{{\[}}%[[VAL_117]]#5, %[[VAL_113]]] %[[VAL_115]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_118:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_11]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_119:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_117]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_119]]{{\[}}%[[VAL_117]]#6, %[[VAL_113]]] %[[VAL_118]], %[[VAL_114]], %[[VAL_76]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_120:.*]] = ttg.local_alloc : () -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_121:.*]]:24 = scf.for %[[VAL_122:.*]] = %[[VAL_14]] to %[[VAL_45]] step %[[VAL_11]] iter_args(%[[VAL_123:.*]] = %[[VAL_78]], %[[VAL_124:.*]] = %[[VAL_117]]#0, %[[VAL_125:.*]] = %[[VAL_117]]#1, %[[VAL_126:.*]] = %[[VAL_117]]#2, %[[VAL_127:.*]] = %[[VAL_117]]#3, %[[VAL_128:.*]] = %[[VAL_117]]#4, %[[VAL_129:.*]] = %[[VAL_117]]#5, %[[VAL_130:.*]] = %[[VAL_117]]#6, %[[VAL_131:.*]] = %[[VAL_23]], %[[VAL_132:.*]] = %[[VAL_10]], %[[VAL_133:.*]] = %[[VAL_11]], %[[VAL_134:.*]] = %[[VAL_13]], %[[VAL_135:.*]] = %[[VAL_14]], %[[VAL_136:.*]] = %[[VAL_117]]#7, %[[VAL_137:.*]] = %[[VAL_117]]#8, %[[VAL_138:.*]] = %[[VAL_117]]#9, %[[VAL_139:.*]] = %[[VAL_14]], %[[VAL_140:.*]] = %[[VAL_78]], %[[VAL_141:.*]] = %[[VAL_36]], %[[VAL_142:.*]] = %[[VAL_117]]#2, %[[VAL_143:.*]] = %[[VAL_73]]#0, %[[VAL_144:.*]] = %[[VAL_117]]#5, %[[VAL_145:.*]] = %[[VAL_73]]#1, %[[VAL_146:.*]] = %[[VAL_117]]#6) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_147:.*]] = arith.subi %[[VAL_45]], %[[VAL_8]] : i32 +// CHECK: %[[VAL_148:.*]] = arith.cmpi slt, %[[VAL_122]], %[[VAL_147]] : i32 +// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_123]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_150:.*]] = arith.addi %[[VAL_123]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_151:.*]] = arith.select %[[VAL_149]], %[[VAL_14]], %[[VAL_150]] : i32 +// CHECK: %[[VAL_152:.*]] = arith.cmpi eq, %[[VAL_151]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_153:.*]] = arith.andi %[[VAL_148]], %[[VAL_152]] : i1 +// CHECK: %[[VAL_154:.*]]:10 = scf.if %[[VAL_153]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32) { +// CHECK: %[[VAL_155:.*]] = arith.addi %[[VAL_128]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_156:.*]] = arith.cmpi eq, %[[VAL_155]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_157:.*]] = arith.select %[[VAL_156]], %[[VAL_14]], %[[VAL_155]] : i32 +// CHECK: %[[VAL_158:.*]]:6 = scf.if %[[VAL_156]] -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32) { +// CHECK: %[[VAL_159:.*]] = tt.addptr %[[VAL_0]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_160:.*]] = arith.muli %[[VAL_136]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_161:.*]] = tt.addptr %[[VAL_47]], %[[VAL_160]] : !tt.ptr, i32 +// CHECK: 
%[[VAL_162:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_163:.*]] = arith.shrsi %[[VAL_162]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_161]], %[[VAL_159]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_5]], %[[VAL_3]]], {{\[}}%[[VAL_163]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_161]] : !tt.ptr +// CHECK: %[[VAL_164:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_161]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_165:.*]] = arith.addi %[[VAL_136]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_166:.*]] = arith.cmpi slt, %[[VAL_165]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_167:.*]] = arith.select %[[VAL_166]], %[[VAL_165]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_168:.*]] = tt.addptr %[[VAL_1]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_169:.*]] = arith.muli %[[VAL_137]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_170:.*]] = tt.addptr %[[VAL_48]], %[[VAL_169]] : !tt.ptr, i32 +// CHECK: %[[VAL_171:.*]] = arith.muli %[[VAL_32]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_172:.*]] = arith.shrsi %[[VAL_171]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_170]], %[[VAL_168]], {{\[}}%[[VAL_18]], %[[VAL_17]]], {{\[}}%[[VAL_5]], %[[VAL_4]]], {{\[}}%[[VAL_172]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_170]] : !tt.ptr +// CHECK: %[[VAL_173:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_170]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_174:.*]] = arith.addi %[[VAL_137]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_175:.*]] = arith.cmpi slt, %[[VAL_174]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_176:.*]] = arith.select %[[VAL_175]], %[[VAL_174]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_177:.*]] = tt.addptr %[[VAL_2]], %[[VAL_44]] : !tt.ptr, i32 +// CHECK: %[[VAL_178:.*]] = arith.muli %[[VAL_138]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_179:.*]] = tt.addptr %[[VAL_49]], %[[VAL_178]] : !tt.ptr, i32 +// CHECK: %[[VAL_180:.*]] = arith.muli %[[VAL_35]], %[[VAL_7]] : i64 +// CHECK: %[[VAL_181:.*]] = arith.shrsi %[[VAL_180]], %[[VAL_6]] : i64 +// CHECK: tt.experimental_tensormap_create %[[VAL_179]], %[[VAL_177]], {{\[}}%[[VAL_18]], %[[VAL_16]]], {{\[}}%[[VAL_4]], %[[VAL_3]]], {{\[}}%[[VAL_181]]], {{\[}}%[[VAL_11]], %[[VAL_11]]] {elem_type = 1 : i32, fill_mode = 0 : i32, interleave_layout = 0 : i32, swizzle_mode = 3 : i32} : (!tt.ptr, !tt.ptr, i32, i32, i32, i32, i64, i32, i32) -> () +// CHECK: tt.experimental_tensormap_fenceproxy_acquire %[[VAL_179]] : !tt.ptr +// CHECK: %[[VAL_182:.*]] = tt.reinterpret_tensor_descriptor %[[VAL_179]] : !tt.ptr to !tt.tensordesc> +// CHECK: %[[VAL_183:.*]] = arith.addi %[[VAL_138]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_184:.*]] = arith.cmpi slt, %[[VAL_183]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_185:.*]] = arith.select %[[VAL_184]], %[[VAL_183]], %[[VAL_14]] : i32 +// CHECK: scf.yield %[[VAL_164]], %[[VAL_173]], %[[VAL_182]], %[[VAL_167]], %[[VAL_176]], %[[VAL_185]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32 +// CHECK: } +// 
CHECK: %[[VAL_186:.*]] = arith.addi %[[VAL_127]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_187:.*]] = arith.divsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_188:.*]] = arith.muli %[[VAL_187]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_189:.*]] = arith.subi %[[VAL_26]], %[[VAL_188]] : i32 +// CHECK: %[[VAL_190:.*]] = arith.minsi %[[VAL_189]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_191:.*]] = arith.remsi %[[VAL_186]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_192:.*]] = arith.addi %[[VAL_188]], %[[VAL_191]] : i32 +// CHECK: %[[VAL_193:.*]] = arith.remsi %[[VAL_186]], %[[VAL_43]] : i32 +// CHECK: %[[VAL_194:.*]] = arith.divsi %[[VAL_193]], %[[VAL_190]] : i32 +// CHECK: %[[VAL_195:.*]] = arith.muli %[[VAL_192]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_196:.*]] = arith.muli %[[VAL_194]], %[[VAL_17]] : i32 +// CHECK: scf.yield %[[VAL_197:.*]]#0, %[[VAL_197]]#1, %[[VAL_197]]#2, %[[VAL_186]], %[[VAL_157]], %[[VAL_195]], %[[VAL_196]], %[[VAL_197]]#3, %[[VAL_197]]#4, %[[VAL_197]]#5 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_124]], %[[VAL_125]], %[[VAL_126]], %[[VAL_127]], %[[VAL_128]], %[[VAL_129]], %[[VAL_130]], %[[VAL_136]], %[[VAL_137]], %[[VAL_138]] : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_198:.*]] = arith.addi %[[VAL_134]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_199:.*]] = arith.cmpi slt, %[[VAL_198]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_200:.*]] = arith.select %[[VAL_199]], %[[VAL_198]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_201:.*]] = arith.xori %[[VAL_135]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_202:.*]] = arith.select %[[VAL_199]], %[[VAL_135]], %[[VAL_201]] : i32 +// CHECK: %[[VAL_203:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_200]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.wait_barrier %[[VAL_203]], %[[VAL_202]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_204:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_205:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_200]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_206:.*]] = ttg.memdesc_trans %[[VAL_204]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_207:.*]] = ttng.warp_group_dot %[[VAL_205]], %[[VAL_206]], %[[VAL_131]], %[[VAL_132]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> * !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_208:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_207]], %[[VAL_205]], %[[VAL_206]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_209:.*]] = arith.addi %[[VAL_133]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_210:.*]] = arith.cmpi slt, %[[VAL_209]], %[[VAL_9]] : i32 +// CHECK: 
%[[VAL_211:.*]] = arith.select %[[VAL_210]], %[[VAL_209]], %[[VAL_14]] : i32 +// CHECK: %[[VAL_212:.*]] = arith.muli %[[VAL_151]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_213:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_211]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_213]], 49152, %[[VAL_148]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_214:.*]] = ttg.memdesc_subview %[[VAL_50]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_215:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216:.*]]#0 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_215]]{{\[}}%[[VAL_216]]#5, %[[VAL_212]]] %[[VAL_214]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_217:.*]] = ttg.memdesc_subview %[[VAL_51]]{{\[}}%[[VAL_211]], %[[VAL_14]], %[[VAL_14]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_218:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_216]]#1 : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_218]]{{\[}}%[[VAL_216]]#6, %[[VAL_212]]] %[[VAL_217]], %[[VAL_213]], %[[VAL_148]] : , <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_219:.*]] = arith.cmpi eq, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: %[[VAL_220:.*]] = arith.cmpi ne, %[[VAL_139]], %[[VAL_46]] : i32 +// CHECK: scf.if %[[VAL_219]] { +// CHECK: %[[VAL_221:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_208]]#0, %[[VAL_205]], %[[VAL_206]] {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_222:.*]] = arith.truncf %[[VAL_221]]#0 : tensor<128x256xf32, #[[$ATTR_0]]> to tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_store %[[VAL_222]], %[[VAL_120]] : tensor<128x256xf16, #[[$ATTR_0]]> -> !ttg.memdesc<128x256xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: ttng.fence_async_shared {bCluster = false} +// CHECK: %[[VAL_223:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_141]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_local_to_global %[[VAL_223]]{{\[}}%[[VAL_143]], %[[VAL_145]]] %[[VAL_120]] : , <128x256xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: } +// CHECK: scf.yield %[[VAL_151]], %[[VAL_216]]#0, %[[VAL_216]]#1, %[[VAL_216]]#2, %[[VAL_216]]#3, %[[VAL_216]]#4, %[[VAL_216]]#5, %[[VAL_216]]#6, %[[VAL_208]]#0, %[[VAL_220]], %[[VAL_211]], %[[VAL_200]], %[[VAL_202]], %[[VAL_216]]#7, %[[VAL_216]]#8, %[[VAL_216]]#9, %[[VAL_140]], %[[VAL_151]], %[[VAL_142]], %[[VAL_216]]#2, %[[VAL_144]], %[[VAL_216]]#5, %[[VAL_146]], %[[VAL_216]]#6 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #[[$ATTR_0]]>, i1, i32, i32, i32, i32, i32, i32, i32, i32, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 +// CHECK: } +// CHECK: ttng.async_tma_store_wait {pendings = 0 : i32} +// CHECK: ttg.local_dealloc %[[VAL_120]] : 
!ttg.memdesc<128x256xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_224:.*]] = ttng.warp_group_dot_wait %[[VAL_225:.*]]#8 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_226:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_227:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_14]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_227]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_228:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_11]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_228]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_229:.*]] = ttg.memdesc_subview %[[VAL_52]]{{\[}}%[[VAL_8]]] : !ttg.memdesc<3xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_229]] : <1xi64, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttg.local_dealloc %[[VAL_50]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: ttg.local_dealloc %[[VAL_51]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_1]], #ttg.shared_memory, mutable> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : 
i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, 
#ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in new file mode 100644 index 000000000000..4ed5dab9cdf5 --- /dev/null +++ b/test/TritonGPU/samples/simulated-grouped-gemm.mlir.in @@ -0,0 +1,104 @@ +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/simulated-grouped-gemm.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/simulated-grouped-gemm.mlir.in \ +// -o test/TritonGPU/samples/simulated-grouped-gemm.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_descriptor_persistent(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 
16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c1_i32 = arith.constant 1 : i32 + %c132_i32 = arith.constant 132 : i32 + %c-1_i32 = arith.constant -1 : i32 + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.addi %arg5, %c63_i32 : i32 + %6 = arith.divsi %5, %c64_i32 : i32 + %7 = arith.muli %2, %4 : i32 + %8 = arith.extsi %arg5 : i32 to i64 + %9 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%8, %c1_i64] : , > + %10 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%8, %c1_i64] : , > + %11 = arith.extsi %arg4 : i32 to i64 + %12 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%11, %c1_i64] : , > + %13 = arith.divsi %7, %c132_i32 : i32 + %14 = arith.remsi %7, %c132_i32 : i32 + %15 = arith.cmpi slt, %0, %14 : i32 + %16 = scf.if %15 -> (i32) { + %23 = arith.addi %13, %c1_i32 : i32 + scf.yield %23 : i32 + } else { + scf.yield %13 : i32 + } + %17 = arith.subi %0, %c132_i32 : i32 + %18 = arith.muli %4, %c8_i32 : i32 + %19 = tt.elementwise_inline_asm "mov.b32 $0, 0;" {constraints = "=r", packed_element = 1 : i32, pure = true} -> i32 + %20 = arith.muli %6, %16 : i32 + %21 = arith.subi %6, %c1_i32 : i32 + %true = arith.constant true + %false = arith.constant false + %22:10 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %c-1_i32, %arg8 = %9, %arg9 = %10, %arg10 = %12, %arg11 = %17, %arg12 = %c-1_i32, %arg13 = %c0_i32, %arg14 = %c0_i32, %arg15 = %cst, %arg16 = %false) -> (i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1) : i32 { + %23 = arith.cmpi eq, %arg7, %21 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %24 = arith.addi %arg7, %c1_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %25 = arith.select %23, %c0_i32, %24 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %26 = arith.cmpi eq, %25, %c0_i32 {loop.cluster = 0 : i32, loop.stage = 0 : i32} : i32 + %27:7 = scf.if %26 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32) { + %37 = arith.addi %arg12, %c1_i32 : i32 + %38 = arith.cmpi eq, %37, %c1_i32 : i32 + %39:4 = scf.if %38 -> (!tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32) { + %51 = tt.addptr %arg0, %19 : !tt.ptr, i32 + %52 = tt.make_tensor_descriptor %51, [%arg3, %arg5], [%8, %c1_i64] : , > + %53 = tt.addptr %arg1, %19 : !tt.ptr, i32 + %54 = tt.make_tensor_descriptor %53, [%arg4, %arg5], [%8, %c1_i64] : , > + %55 = tt.addptr %arg2, %19 : !tt.ptr, i32 + %56 = tt.make_tensor_descriptor %55, [%arg3, %arg4], [%11, %c1_i64] : , > + scf.yield %52, %54, %56, %c0_i32 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %37 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32 + } + %40 = 
arith.addi %arg11, %c132_i32 : i32 + %41 = arith.divsi %40, %18 : i32 + %42 = arith.muli %41, %c8_i32 : i32 + %43 = arith.subi %2, %42 : i32 + %44 = arith.minsi %43, %c8_i32 : i32 + %45 = arith.remsi %40, %44 : i32 + %46 = arith.addi %42, %45 : i32 + %47 = arith.remsi %40, %18 : i32 + %48 = arith.divsi %47, %44 : i32 + %49 = arith.muli %46, %c128_i32 : i32 + %50 = arith.muli %48, %c256_i32 : i32 + scf.yield %39#0, %39#1, %39#2, %40, %39#3, %49, %50 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } else { + scf.yield %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14 : !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32 + } {loop.cluster = 0 : i32, loop.stage = 0 : i32} + %28 = arith.muli %25, %c64_i32 {loop.cluster = 2 : i32, loop.stage = 0 : i32} : i32 + %29 = tt.experimental_descriptor_load %27#0[%27#5, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %30 = ttg.local_alloc %29 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = tt.experimental_descriptor_load %27#1[%27#6, %28] {loop.cluster = 2 : i32, loop.stage = 0 : i32} : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %32 = ttg.local_alloc %31 {loop.cluster = 1 : i32, loop.stage = 2 : i32} : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %33 = ttg.memdesc_trans %32 {loop.cluster = 1 : i32, loop.stage = 2 : i32, order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %34 = ttng.warp_group_dot %30, %33, %arg15, %arg16 {inputPrecision = 0 : i32, loop.cluster = 1 : i32, loop.stage = 2 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %35 = arith.cmpi eq, %25, %21 {loop.cluster = 3 : i32, loop.stage = 2 : i32} : i32 + %36 = scf.if %35 -> (i1) { + %37 = arith.truncf %34 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %38 = ttg.convert_layout %37 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = 
[1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %27#2[%27#5, %27#6], %38 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + scf.yield %false : i1 + } else { + scf.yield %true : i1 + } {loop.cluster = 3 : i32, loop.stage = 2 : i32} + scf.yield %25, %27#0, %27#1, %27#2, %27#3, %27#4, %27#5, %27#6, %34, %36 : i32, !tt.tensordesc>, !tt.tensordesc>, !tt.tensordesc>, i32, i32, i32, i32, tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i1 + } + tt.return + } +} diff --git a/utils/generate-test-checks.py b/utils/generate-test-checks.py index d66c269eb59e..f6041f40dbe5 100755 --- a/utils/generate-test-checks.py +++ b/utils/generate-test-checks.py @@ -259,7 +259,7 @@ def main(): "file, respectively. The delimeter lines are identified by " "--source_delim_regex.", ) - parser.add_argument("--source_delim_regex", type=str, default="func @") + parser.add_argument("--source_delim_regex", type=str, default="module") parser.add_argument( "--starts_from_scope", type=int,