From 83cf4362bd87c74bf57e79a7e213b4301fa3f25c Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Sun, 1 Dec 2024 20:45:57 +0000 Subject: [PATCH] [TESTING] Add golden sample test for pipelining matmul with descriptors (#5289) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [TESTING] Add golden sample test for pipelining matmul with descriptors #### [PR chain](https://github.com/jlebar/git-pr-chain) 1. #5288 1. 👉 #5289 👈 **YOU ARE HERE** 1. #5290 ⚠️⚠️ Please **do not click the green "merge" button** unless you know what you're doing. This PR is part of a chain of PRs, and clicking the merge button will not merge it into master. ⚠️⚠️ --- .../samples/descriptor-matmul-pipeline.mlir | 175 ++++++++ .../descriptor-matmul-pipeline.mlir.in | 57 +++ utils/generate-test-checks.py | 402 ++++++++++++++++++ 3 files changed, 634 insertions(+) create mode 100644 test/TritonGPU/samples/descriptor-matmul-pipeline.mlir create mode 100644 test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in create mode 100755 utils/generate-test-checks.py diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir new file mode 100644 index 000000000000..4bcff281b5c0 --- /dev/null +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir @@ -0,0 +1,175 @@ +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. 
+ +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], hasLeadingOffset = false}> +// CHECK: #[[$ATTR_4:.+]] = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +// CHECK-LABEL: tt.func public @matmul_kernel_with_descriptors( +// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_1:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_2:.*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[VAL_3:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_4:.*]]: i32 {tt.divisibility = 16 : i32}, %[[VAL_5:.*]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : i32 +// CHECK: %[[VAL_7:.*]] = arith.constant 3 : i32 +// CHECK: %[[VAL_8:.*]] = arith.constant -1 : i32 +// CHECK: %[[VAL_9:.*]] = arith.constant 8 : i32 +// CHECK: %[[VAL_10:.*]] = arith.constant 128 : i32 +// CHECK: %[[VAL_11:.*]] = arith.constant 256 : i32 +// CHECK: %[[VAL_12:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_13:.*]] = arith.constant 64 : i32 +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_15:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_16:.*]] = arith.constant 127 : i32 +// CHECK: %[[VAL_17:.*]] = arith.constant 255 : i32 +// CHECK: %[[VAL_18:.*]] = arith.constant 63 : i32 +// CHECK: %[[VAL_19:.*]] = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_20:.*]] = 
tt.get_program_id x : i32 +// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_3]], %[[VAL_16]] : i32 +// CHECK: %[[VAL_22:.*]] = arith.divsi %[[VAL_21]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_4]], %[[VAL_17]] : i32 +// CHECK: %[[VAL_24:.*]] = arith.divsi %[[VAL_23]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_25:.*]] = arith.muli %[[VAL_24]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_26:.*]] = arith.divsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_27:.*]] = arith.muli %[[VAL_26]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_22]], %[[VAL_27]] : i32 +// CHECK: %[[VAL_29:.*]] = arith.minsi %[[VAL_28]], %[[VAL_9]] : i32 +// CHECK: %[[VAL_30:.*]] = arith.remsi %[[VAL_20]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_31:.*]] = arith.addi %[[VAL_27]], %[[VAL_30]] : i32 +// CHECK: %[[VAL_32:.*]] = arith.remsi %[[VAL_20]], %[[VAL_25]] : i32 +// CHECK: %[[VAL_33:.*]] = arith.divsi %[[VAL_32]], %[[VAL_29]] : i32 +// CHECK: %[[VAL_34:.*]] = arith.extsi %[[VAL_5]] : i32 to i64 +// CHECK: %[[VAL_35:.*]] = tt.make_tensor_descriptor %[[VAL_0]], {{\[}}%[[VAL_3]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_36:.*]] = tt.make_tensor_descriptor %[[VAL_1]], {{\[}}%[[VAL_4]], %[[VAL_5]]], {{\[}}%[[VAL_34]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_37:.*]] = arith.extsi %[[VAL_4]] : i32 to i64 +// CHECK: %[[VAL_38:.*]] = tt.make_tensor_descriptor %[[VAL_2]], {{\[}}%[[VAL_3]], %[[VAL_4]]], {{\[}}%[[VAL_37]], %[[VAL_14]]] : , > +// CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_31]], %[[VAL_10]] : i32 +// CHECK: %[[VAL_40:.*]] = arith.muli %[[VAL_33]], %[[VAL_11]] : i32 +// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_5]], %[[VAL_18]] : i32 +// CHECK: %[[VAL_42:.*]] = arith.divsi %[[VAL_41]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_43:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_44:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// 
CHECK: %[[VAL_45:.*]] = ttg.local_alloc : () -> !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_46:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_46]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_47:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_47]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_48:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.init_barrier %[[VAL_48]], 1 : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_49:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_50:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_50]], 49152, %[[VAL_49]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_51:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_52:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_52]]{{\[}}%[[VAL_39]], %[[VAL_12]]] %[[VAL_51]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], 
#ttg.shared_memory, mutable> +// CHECK: %[[VAL_53:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_12]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_54:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_54]]{{\[}}%[[VAL_40]], %[[VAL_12]]] %[[VAL_53]], %[[VAL_50]], %[[VAL_49]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_55:.*]] = arith.cmpi sgt, %[[VAL_42]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_56:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_56]], 49152, %[[VAL_55]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_57:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_58:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_58]]{{\[}}%[[VAL_39]], %[[VAL_13]]] %[[VAL_57]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_59:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_15]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_60:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local 
%[[VAL_60]]{{\[}}%[[VAL_40]], %[[VAL_13]]] %[[VAL_59]], %[[VAL_56]], %[[VAL_55]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_61:.*]]:5 = scf.for %[[VAL_62:.*]] = %[[VAL_12]] to %[[VAL_42]] step %[[VAL_15]] iter_args(%[[VAL_63:.*]] = %[[VAL_19]], %[[VAL_64:.*]] = %[[VAL_13]], %[[VAL_65:.*]] = %[[VAL_15]], %[[VAL_66:.*]] = %[[VAL_8]], %[[VAL_67:.*]] = %[[VAL_12]]) -> (tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32) : i32 { +// CHECK: %[[VAL_68:.*]] = arith.subi %[[VAL_42]], %[[VAL_6]] : i32 +// CHECK: %[[VAL_69:.*]] = arith.cmpi slt, %[[VAL_62]], %[[VAL_68]] : i32 +// CHECK: %[[VAL_70:.*]] = arith.addi %[[VAL_66]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_71:.*]] = arith.cmpi slt, %[[VAL_70]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_72:.*]] = arith.select %[[VAL_71]], %[[VAL_70]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_73:.*]] = arith.xori %[[VAL_67]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_74:.*]] = arith.select %[[VAL_71]], %[[VAL_67]], %[[VAL_73]] : i32 +// CHECK: %[[VAL_75:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_72]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.wait_barrier %[[VAL_75]], %[[VAL_74]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_76:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_77:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_72]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_78:.*]] = ttg.memdesc_trans %[[VAL_76]] {order = array} : !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, 
mutable> -> !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_79:.*]] = ttng.warp_group_dot %[[VAL_77]], %[[VAL_78]], %[[VAL_63]] {inputPrecision = 0 : i32, isAsync = true} : !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> * !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> -> tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_80:.*]]:3 = ttng.warp_group_dot_wait %[[VAL_79]], %[[VAL_77]], %[[VAL_78]] {pendings = 1 : i32} : tensor<128x256xf32, #[[$ATTR_1]]>, !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable>, !ttg.memdesc<64x256xf16, #[[$ATTR_4]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_81:.*]] = arith.addi %[[VAL_64]], %[[VAL_13]] : i32 +// CHECK: %[[VAL_82:.*]] = arith.addi %[[VAL_65]], %[[VAL_15]] : i32 +// CHECK: %[[VAL_83:.*]] = arith.cmpi slt, %[[VAL_82]], %[[VAL_7]] : i32 +// CHECK: %[[VAL_84:.*]] = arith.select %[[VAL_83]], %[[VAL_82]], %[[VAL_12]] : i32 +// CHECK: %[[VAL_85:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_84]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.barrier_expect %[[VAL_85]], 49152, %[[VAL_69]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_86:.*]] = ttg.memdesc_subview %[[VAL_43]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_87:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_35]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_87]]{{\[}}%[[VAL_39]], %[[VAL_81]]] %[[VAL_86]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_88:.*]] = ttg.memdesc_subview %[[VAL_44]]{{\[}}%[[VAL_84]], %[[VAL_12]], %[[VAL_12]]] : 
!ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_89:.*]] = ttng.tensor_desc_to_tma_ptr %[[VAL_36]] : !tt.tensordesc> to !tt.ptr +// CHECK: ttng.async_tma_copy_global_to_local %[[VAL_89]]{{\[}}%[[VAL_40]], %[[VAL_81]]] %[[VAL_88]], %[[VAL_85]], %[[VAL_69]] : , <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> <256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: scf.yield %[[VAL_80]]#0, %[[VAL_81]], %[[VAL_84]], %[[VAL_72]], %[[VAL_74]] : tensor<128x256xf32, #[[$ATTR_1]]>, i32, i32, i32, i32 +// CHECK: } +// CHECK: %[[VAL_90:.*]] = ttng.warp_group_dot_wait %[[VAL_91:.*]]#0 {pendings = 0 : i32} : tensor<128x256xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_92:.*]] = ttg.async_wait {num = 0 : i32} +// CHECK: %[[VAL_93:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_12]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_93]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_94:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_15]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_94]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_95:.*]] = ttg.memdesc_subview %[[VAL_45]]{{\[}}%[[VAL_6]]] : !ttg.memdesc<3xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> -> !ttg.memdesc<1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttng.inval_barrier %[[VAL_95]] : <1xi64, #[[$ATTR_3]], #ttg.shared_memory, mutable> +// CHECK: ttg.local_dealloc %[[VAL_43]] : !ttg.memdesc<3x128x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: ttg.local_dealloc %[[VAL_44]] : !ttg.memdesc<3x256x64xf16, #[[$ATTR_2]], #ttg.shared_memory, mutable> +// CHECK: %[[VAL_96:.*]] = arith.truncf 
%[[VAL_90]] : tensor<128x256xf32, #[[$ATTR_1]]> to tensor<128x256xf16, #[[$ATTR_1]]> +// CHECK: %[[VAL_97:.*]] = ttg.convert_layout %[[VAL_96]] : tensor<128x256xf16, #[[$ATTR_1]]> -> tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.experimental_descriptor_store %[[VAL_38]]{{\[}}%[[VAL_39]], %[[VAL_40]]], %[[VAL_97]] : !tt.tensordesc>, tensor<128x256xf16, #[[$ATTR_0]]> +// CHECK: tt.return +// CHECK: } +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, 
[%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 
{inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> * !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in new file mode 100644 index 000000000000..5175c1a6cefa --- /dev/null +++ b/test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in @@ -0,0 +1,57 @@ +// To regenerate this test case, run the command +// triton-opt test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | \ +// utils/generate-test-checks.py --source test/TritonGPU/samples/descriptor-matmul-pipeline.mlir.in 
--source_delim_regex="module" \ +// -o test/TritonGPU/samples/descriptor-matmul-pipeline.mlir +// RUN: triton-opt %s -split-input-file -tritongpu-loop-scheduling -tritongpu-pipeline -canonicalize | FileCheck --dump-input-context=50 %s +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_with_descriptors(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c8_i32 = arith.constant 8 : i32 + %c128_i32 = arith.constant 128 : i32 + %c256_i32 = arith.constant 256 : i32 + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %c1_i64 = arith.constant 1 : i64 + %c1_i32 = arith.constant 1 : i32 + %c127_i32 = arith.constant 127 : i32 + %c255_i32 = arith.constant 255 : i32 + %c63_i32 = arith.constant 63 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %0 = tt.get_program_id x : i32 + %1 = arith.addi %arg3, %c127_i32 : i32 + %2 = arith.divsi %1, %c128_i32 : i32 + %3 = arith.addi %arg4, %c255_i32 : i32 + %4 = arith.divsi %3, %c256_i32 : i32 + %5 = arith.muli %4, %c8_i32 : i32 + %6 = arith.divsi %0, %5 : i32 + %7 = arith.muli %6, %c8_i32 : i32 + %8 = arith.subi %2, %7 : i32 + %9 = arith.minsi %8, %c8_i32 : i32 + %10 = arith.remsi %0, %9 : i32 + %11 = arith.addi %7, %10 : i32 + %12 = arith.remsi %0, %5 : i32 + %13 = arith.divsi %12, %9 : i32 + %14 = arith.extsi %arg5 : i32 to i64 + %15 = tt.make_tensor_descriptor %arg0, [%arg3, %arg5], [%14, %c1_i64] : , > + %16 = tt.make_tensor_descriptor %arg1, [%arg4, %arg5], [%14, %c1_i64] : , > + %17 = arith.extsi %arg4 : i32 to i64 + %18 = 
tt.make_tensor_descriptor %arg2, [%arg3, %arg4], [%17, %c1_i64] : , > + %19 = arith.muli %11, %c128_i32 : i32 + %20 = arith.muli %13, %c256_i32 : i32 + %21 = arith.addi %arg5, %c63_i32 : i32 + %22 = arith.divsi %21, %c64_i32 : i32 + %23:2 = scf.for %arg6 = %c0_i32 to %22 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %c0_i32) -> (tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32) : i32 { + %26 = tt.experimental_descriptor_load %15[%19, %arg8] : !tt.tensordesc> -> tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %27 = ttg.local_alloc %26 : (tensor<128x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %28 = tt.experimental_descriptor_load %16[%20, %arg8] : !tt.tensordesc> -> tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>> + %29 = ttg.local_alloc %28 : (tensor<256x64xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>>) -> !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> + %30 = ttg.memdesc_trans %29 {order = array} : !ttg.memdesc<256x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> -> !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> + %31 = ttng.warp_group_dot %27, %30, %arg7 {inputPrecision = 0 : i32} : !ttg.memdesc<128x64xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>, #ttg.shared_memory> 
* !ttg.memdesc<64x256xf16, #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}>, #ttg.shared_memory> -> tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %32 = arith.addi %arg8, %c64_i32 : i32 + scf.yield %31, %32 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>>, i32 + } + %24 = arith.truncf %23#0 : tensor<128x256xf32, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> to tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> + %25 = ttg.convert_layout %24 : tensor<128x256xf16, #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 256, 16]}>> -> tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.experimental_descriptor_store %18[%19, %20], %25 : !tt.tensordesc>, tensor<128x256xf16, #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 8], order = [1, 0]}>> + tt.return + } +} diff --git a/utils/generate-test-checks.py b/utils/generate-test-checks.py new file mode 100755 index 000000000000..d66c269eb59e --- /dev/null +++ b/utils/generate-test-checks.py @@ -0,0 +1,402 @@ +#!/usr/bin/env python3 +"""A script to generate FileCheck statements for mlir unit tests. + +This script is a utility to add FileCheck patterns to an mlir file. + +NOTE: The input .mlir is expected to be the output from the parser, not a +stripped down variant. 
+ +Example usage: +$ generate-test-checks.py foo.mlir +$ mlir-opt foo.mlir -transformation | generate-test-checks.py +$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir +$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i +$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @' + +The script will heuristically generate CHECK/CHECK-LABEL commands for each line +within the file. By default this script will also try to insert string +substitution blocks for all SSA value names. If --source file is specified, the +script will attempt to insert the generated CHECKs to the source file by looking +for line positions matched by --source_delim_regex. + +The script is designed to make adding checks to a test case fast, it is *not* +designed to be authoritative about what constitutes a good test! +""" + +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import os # Used to advertise this file's name ("autogenerated_note"). +import re +import sys +from typing import Optional + +ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by " +ADVERT_END = """ +// The script is designed to make adding checks to +// a test case fast, it is *not* designed to be authoritative +// about what constitutes a good test! The CHECK should be +// minimized and named to reflect the test intent. +""" + +# Regex command to match an SSA identifier. 
# Regex matching a single SSA value name (either numeric, e.g. '0', or
# symbolic, e.g. 'arg0').
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching the left-hand side of an assignment
# NOTE(review): because '|' binds loosely, the group alternatives are
# '%[0-9]+' and a bare identifier — so only purely numeric result names
# ('%0, %1 = ...') are counted; named results like '%c0' match neither
# alternative and yield 0. This mirrors the upstream script; confirm
# before "fixing" since generated-name scoping depends on it.
SSA_RESULTS_STR = r'\s*(%' + SSA_RE_STR + r')(\s*,\s*(%' + SSA_RE_STR + r'))*\s*='
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes
ATTR_RE_STR = r'(#[a-zA-Z._-][a-zA-Z0-9._-]*)'
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition
ATTR_DEF_RE_STR = r'\s*' + ATTR_RE_STR + r'\s*='
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)


class VariableNamer:
    """Generates and manages FileCheck substitution names for SSA values.

    Substitution names are tracked per nesting scope (a scope is pushed for
    every '{' and popped for every '}'), so a '%0' in one region does not
    collide with a '%0' in a sibling region.
    """

    def __init__(self, variable_names):
        """variable_names: comma-separated user-chosen names; empty entries
        fall back to the generated 'VAL_<n>' scheme."""
        self.scopes = []
        self.name_counter = 0

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # Parse variable names
        self.variable_names = [name.upper() for name in variable_names.split(',')]
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Generate the following 'n' variable names in the parent scope
        (used for op results, which outlive the region opened on the same
        line)."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name):
        """Generate and record a substitution name for the given SSA value
        name, preferring user-supplied names and falling back to 'VAL_<n>'.

        Raises RuntimeError on a duplicate user-supplied name.
        """

        # Compute variable name
        variable_name = self.variable_names.pop(0) if len(self.variable_names) > 0 else ''
        if variable_name == '':
            variable_name = "VAL_" + str(self.name_counter)
            self.name_counter += 1

        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert (scope >= 0)

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ': duplicate variable name')
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    def push_name_scope(self):
        """Push a new variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the last variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Reset the counter and used variable names (called at the start of
        each new output segment so numbering restarts per function)."""
        self.name_counter = 0
        self.used_variable_names = set()


class AttributeNamer:
    """Generates and manages FileCheck substitution names for attribute
    aliases (e.g. '#map0'), which are global to the file rather than
    scoped."""

    def __init__(self, attribute_names):
        """attribute_names: comma-separated user-chosen names; empty entries
        fall back to the generated 'ATTR_<n>' scheme."""
        self.name_counter = 0
        self.attribute_names = [name.upper() for name in attribute_names.split(',')]
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Generate and record a substitution name for the given attribute
        name; the '$' prefix makes it a FileCheck global.

        Raises RuntimeError on a duplicate user-supplied name.
        """

        # Compute FileCheck name
        attribute_name = self.attribute_names.pop(0) if len(self.attribute_names) > 0 else ''
        if attribute_name == '':
            attribute_name = "ATTR_" + str(self.name_counter)
            self.name_counter += 1

        # Prepend global symbol
        attribute_name = '$' + attribute_name

        # Save attribute
        if attribute_name in self.used_attribute_names:
            raise RuntimeError(attribute_name + ': duplicate attribute name')
        self.map[source_attribute_name] = attribute_name
        self.used_attribute_names.add(attribute_name)
        return attribute_name

    def get_name(self, source_attribute_name) -> Optional[str]:
        """Get the saved substitution name for the given attribute name, if
        it exists."""
        return self.map.get(source_attribute_name)


def get_num_ssa_results(input_line):
    """Return the number of SSA results in a line of type
    '%0, %1, ... = ...'; returns 0 if there are no results."""
    m = SSA_RESULTS_RE.match(input_line)
    return m.group().count('%') if m else 0


def process_line(line_chunks, variable_namer):
    """Process a line of input that has been split at each SSA identifier
    '%': known values become '%[[NAME]]' references, new values become
    '%[[NAME:.*]]' captures. Returns the rebuilt line, rstripped, with a
    trailing newline."""
    output_line = ""

    # Process the rest that contained an SSA value name.
    for chunk in line_chunks:
        m = SSA_RE.match(chunk)
        ssa_name = m.group(0) if m is not None else ''

        # Check if an existing variable exists for this name.
        variable = None
        for scope in variable_namer.scopes:
            variable = scope.get(ssa_name)
            if variable is not None:
                break

        # If one exists, then output the existing name.
        if variable is not None:
            output_line += "%[[" + variable + "]]"
        else:
            # Otherwise, generate a new variable.
            variable = variable_namer.generate_name(ssa_name)
            output_line += "%[[" + variable + ":.*]]"

        # Append the non named group.
        output_line += chunk[len(ssa_name):]

    return output_line.rstrip() + "\n"


def process_source_lines(source_lines, note, args):
    """Split the source file into segments at lines matching
    args.source_delim_regex, dropping the previous autogenerated note and
    any previous CHECK lines. The source file doesn't have to be .mlir.
    Returns a list of segments, each a list of newline-terminated lines."""
    source_split_re = re.compile(args.source_delim_regex)

    source_segments = [[]]
    for line in source_lines:
        # Remove previous note.
        if line == note:
            continue
        # Remove previous CHECK lines.
        if line.find(args.check_prefix) != -1:
            continue
        # Segment the file based on --source_delim_regex.
        if source_split_re.search(line):
            source_segments.append([])

        source_segments[-1].append(line + "\n")
    return source_segments


def process_attribute_definition(line, attribute_namer, output):
    """If 'line' is an attribute definition ('#name = ...'), emit a
    '// CHECK: #[[NAME:.+]] = ...' line into 'output'; otherwise do
    nothing."""
    m = ATTR_DEF_RE.match(line)
    if m:
        attribute_name = attribute_namer.generate_name(m.group(1))
        line = '// CHECK: #[[' + attribute_name + ':.+]] =' + line[len(m.group(0)):] + '\n'
        output.append(line)


def process_attribute_references(line, attribute_namer):
    """Replace every known attribute reference in 'line' with its
    '#[[NAME]]' substitution; unknown attributes pass through unchanged."""

    output_line = ''
    # ATTR_RE has a capture group, so split() interleaves the matched
    # attribute names with the surrounding text.
    components = ATTR_RE.split(line)
    for component in components:
        m = ATTR_RE.match(component)
        name = attribute_namer.get_name(m.group(1)) if m else None
        if name is None:
            output_line += component
        else:
            output_line += '#[[' + name + ']]'
            output_line += component[len(m.group()):]
    return output_line


def preprocess_line(line):
    """Pre-process a line of input to escape character sequences that are
    meaningful to FileCheck."""
    # Replace any double brackets, '[[' with escaped replacements. '[['
    # corresponds to variable names in FileCheck.
    output_line = line.replace("[[", "{{\\[\\[}}")

    # Replace any single brackets that are followed by an SSA identifier, the
    # identifier will be replace by a variable; Creating the same situation as
    # above.
    output_line = output_line.replace("[%", "{{\\[}}%")

    return output_line


def main():
    """Read MLIR from the input, generate FileCheck CHECK lines for it, and
    write them to the output (optionally interleaved with --source
    segments, or back in place with --inplace)."""
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--check-prefix", default="CHECK", help="Prefix to use from check file.")
    parser.add_argument("-o", "--output", nargs="?", type=argparse.FileType("w"), default=None)
    parser.add_argument("input", nargs="?", type=argparse.FileType("r"), default=sys.stdin)
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names", type=str, default='',
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')")
    parser.add_argument(
        "--attribute_names", type=str, default='', help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')")

    args = parser.parse_args()

    # Open the given input file.
    input_lines = [l.rstrip() for l in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        source_segments = process_source_lines([l.rstrip() for l in open(args.source, "r")], autogenerated_note, args)

    if args.inplace:
        # --inplace rewrites the --source file itself, so -o must not be set.
        assert args.output is None
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    output_segments = [[]]

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # Check if this is an attribute definition and process it
        process_attribute_definition(input_line, attribute_namer, output_segments[-1])

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

        # Result SSA values must still be pushed to parent scope
        num_ssa_results = get_num_ssa_results(input_line)
        variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        # Restart VAL_* numbering at the start of each new segment.
        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic with
        # FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Output the first line chunk that does not contain an SSA name for the
            # label.
            output_line = "// " + args.check_prefix + "-LABEL: " + ssa_split[0] + "\n"

            # Process the rest of the input line on separate check lines.
            output_line += "// " + args.check_prefix + "-SAME: "
            output_line += process_line(ssa_split[1:], variable_namer)

        # Append the output line.
        output_segments[-1].append(output_line)

    output.write(autogenerated_note + "\n")

    # Write the output.
    if source_segments:
        # One CHECK segment must exist per source segment so they interleave.
        assert len(output_segments) == len(source_segments), (len(output_segments), len(source_segments))
        for check_segment, source_segment in zip(output_segments, source_segments):
            for line in check_segment:
                output.write(line)
            for line in source_segment:
                output.write(line)
    else:
        for segment in output_segments:
            output.write("\n")
            for output_line in segment:
                output.write(output_line)
        output.write("\n")
    # NOTE(review): this also closes sys.stdout when no -o/--inplace is
    # given — matches upstream behavior; confirm before changing.
    output.close()


if __name__ == "__main__":
    main()