Skip to content

Commit

Permalink
Add ReturnLike to terminator op
Browse files Browse the repository at this point in the history
Signed-off-by: philass <[email protected]>
  • Loading branch information
philass committed Nov 8, 2023
1 parent bd26475 commit 6ba10ca
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 52 deletions.
3 changes: 2 additions & 1 deletion src/Dialect/Krnl/Krnl.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include "mlir/IR/OpBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
Expand Down Expand Up @@ -285,7 +286,7 @@ def KrnlSeqStoreOp : Op<Krnl_Dialect, "seqstore", [MemRefsNormalizable]> {
Index:$index);
}

def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [ReturnLike, Terminator]> {
let summary = "Krnl terminator operation";
let description = [{
Krnl terminator is a special terminator operation for blocks inside krnl
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,66 +8,57 @@ func.func @test_sequence_erase(%arg0: !onnx.Seq<tensor<?x4x5xf32>>) -> tensor<3x
%4 = "onnx.SequenceAt"(%7, %0) : (!onnx.Seq<tensor<?x4x5xf32>>, tensor<i64>) -> tensor<?x4x5xf32>
%5 = "onnx.Shape"(%4) : (tensor<?x4x5xf32>) -> tensor<3xi64>
return %5 : tensor<3xi64>

// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 - 1)>
// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1] -> (s0)>
// CHECK-LABEL: func.func @test_sequence_erase
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?xmemref<?x4x5xf32>>) -> memref<3xi64> {
// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = {{.*}}, shape = [], value = dense<0> : tensor<i64>} : () -> memref<i64>
// CHECK-DAG: [[VAR_c0_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_1_:%.+]] = memref.dim [[PARAM_0_]], [[VAR_c0_]] : memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[VAR_c1_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_2_:%.+]] = affine.apply {{#map.*}}(){{.}}[[VAR_1_]]{{.}}
// CHECK-DAG: [[VAR_3_:%.+]] = "krnl.seqalloc"([[VAR_2_]]) : (index) -> memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : i64
// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [], value = dense<0> : tensor<i64>} : () -> memref<i64>
// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref<?xmemref<?x4x5xf32>>
// CHECK: [[VAR_1_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[VAR_2_:%.+]] = "krnl.seqalloc"([[VAR_1_]]) : (index) -> memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[LOAD_VAR_0_MEM_:%.+]] = krnl.load [[VAR_0_]][] : memref<i64>
// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[LOAD_VAR_0_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_6_:%.+]] = affine.apply {{#map.*}}(){{.}}[[VAR_1_]], [[VAR_5_]]{{.}}
// CHECK-DAG: [[VAR_c0_0_:%.+]] = arith.constant 0 : index
// CHECK: [[VAR_7_:%.+]] = arith.cmpi slt, [[VAR_5_]], [[VAR_c0_0_]] : index
// CHECK-DAG: [[VAR_8_:%.+]] = arith.select [[VAR_7_]], [[VAR_6_]], [[VAR_5_]] : index
// CHECK-DAG: [[VAR_c0_1_:%.+]] = arith.constant 0 : index
// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[LOAD_VAR_0_MEM_]] : i64 to index
// CHECK-DAG: [[VAR_5_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_dim_]], [[VAR_4_]]{{.}}
// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi slt, [[VAR_4_]], [[CST_0_]] : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_7_:%.+]] = arith.select [[VAR_6_]], [[VAR_5_]], [[VAR_4_]] : index
// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[VAR_8_]]){
// CHECK: [[VAR_24_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_24_]]{{.}} : memref<?xmemref<?x4x5xf32>>
// CHECK: "krnl.seqstore"([[LOAD_PARAM_0_MEM_]], [[VAR_3_]], [[VAR_8_]]) : (memref<?x4x5xf32>, memref<?xmemref<?x4x5xf32>>, index) -> ()
// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to [[VAR_7_]]){
// CHECK: [[VAR_18_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_18_]]{{.}} : memref<?xmemref<?x4x5xf32>>
// CHECK: "krnl.seqstore"([[LOAD_PARAM_0_MEM_]], [[VAR_2_]], [[VAR_7_]]) : (memref<?x4x5xf32>, memref<?xmemref<?x4x5xf32>>, index) -> ()
// CHECK: }
// CHECK: [[VAR_c1_2_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_8_]], [[VAR_c1_2_]] : index
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addi [[VAR_7_]], [[CST_1_]] : index
// CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = [[VAR_10_]] to {{#map.*}}(){{.}}[[VAR_1_]], [[VAR_5_]]{{.}}){
// CHECK: [[VAR_24_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_24_1_]]{{.}} : memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[VAR_c1_8_:%.+]] = arith.constant 1 : index
// CHECK: [[VAR_26_:%.+]] = arith.subi [[VAR_24_1_]], [[VAR_c1_8_]] : index
// CHECK: "krnl.seqstore"([[LOAD_PARAM_0_MEM_1_]], [[VAR_3_]], [[VAR_26_]]) : (memref<?x4x5xf32>, memref<?xmemref<?x4x5xf32>>, index) -> ()
// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = [[VAR_9_]] to [[MAP_2_]](){{.}}[[VAR_dim_]], [[VAR_4_]]{{.}}){
// CHECK: [[VAR_18_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_18_1_]]{{.}} : memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[VAR_20_:%.+]] = arith.subi [[VAR_18_1_]], [[CST_1_]] : index
// CHECK: "krnl.seqstore"([[LOAD_PARAM_0_MEM_1_]], [[VAR_2_]], [[VAR_2_]]0) : (memref<?x4x5xf32>, memref<?xmemref<?x4x5xf32>>, index) -> ()
// CHECK: }
// CHECK: [[VAR_c0_3_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[VAR_12_:%.+]] = memref.dim [[VAR_3_]], [[VAR_c0_3_]] : memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[VAR_dim_0_:%.+]] = memref.dim [[VAR_2_]], [[CST_0_]] : memref<?xmemref<?x4x5xf32>>
// CHECK-DAG: [[LOAD_VAR_0_MEM_1_:%.+]] = krnl.load [[VAR_0_]][] : memref<i64>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[LOAD_VAR_0_MEM_1_]] : i64 to index
// CHECK-DAG: [[VAR_c0_4_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpi slt, [[VAR_14_]], [[VAR_c0_4_]] : index
// CHECK-DAG: [[VAR_16_:%.+]] = affine.apply {{#map.+}}(){{.}}{{.*}}, {{.*}}{{.}}
// CHECK: [[VAR_17_:%.+]] = arith.select [[VAR_15_]], [[VAR_16_]], [[VAR_14_]] : index
// CHECK: [[VAR_18_:%.+]] = "krnl.seqextract"([[VAR_3_]], [[VAR_17_]]) {copy = 1 : ui1} : (memref<?xmemref<?x4x5xf32>>, index) -> memref<?x4x5xf32>
// CHECK-DAG: [[VAR_c3_:%.+]] = arith.constant 3 : index
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[LOAD_VAR_0_MEM_1_]] : i64 to index
// CHECK-DAG: [[VAR_13_:%.+]] = arith.cmpi slt, [[VAR_12_]], [[CST_0_]] : index
// CHECK-DAG: [[VAR_14_:%.+]] = affine.apply [[MAP_1_]](){{.}}[[VAR_12_]], [[VAR_dim_0_]]{{.}}
// CHECK: [[VAR_15_:%.+]] = arith.select [[VAR_13_]], [[VAR_14_]], [[VAR_12_]] : index
// CHECK-DAG: [[VAR_16_:%.+]] = "krnl.seqextract"([[VAR_2_]], [[VAR_15_]]) {copy = 1 : ui1} : (memref<?xmemref<?x4x5xf32>>, index) -> memref<?x4x5xf32>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<3xi64>
// CHECK-DAG: [[VAR_c0_5_:%.+]] = arith.constant 0 : index
// CHECK: [[VAR_20_:%.+]] = memref.dim [[VAR_18_]], [[VAR_c0_5_]] : memref<?x4x5xf32>
// CHECK: memref.dealloc [[VAR_18_]] : memref<?x4x5xf32>
// CHECK: "krnl.seqdealloc"([[VAR_3_]]) : (memref<?xmemref<?x4x5xf32>>) -> ()
// CHECK-DAG: [[VAR_c4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[VAR_c5_:%.+]] = arith.constant 5 : index
// CHECK-DAG: [[VAR_21_:%.+]] = arith.index_cast [[VAR_20_]] : index to i64
// CHECK-DAG: [[VAR_c0_6_:%.+]] = arith.constant 0 : index
// CHECK: krnl.store [[VAR_21_]], [[RES_]]{{.}}[[VAR_c0_6_]]{{.}} : memref<3xi64>
// CHECK-DAG: [[VAR_22_:%.+]] = arith.index_cast [[VAR_c4_]] : index to i64
// CHECK-DAG: [[VAR_c1_7_:%.+]] = arith.constant 1 : index
// CHECK: krnl.store [[VAR_22_]], [[RES_]]{{.}}[[VAR_c1_7_]]{{.}} : memref<3xi64>
// CHECK-DAG: [[VAR_23_:%.+]] = arith.index_cast [[VAR_c5_]] : index to i64
// CHECK-DAG: [[VAR_c2_:%.+]] = arith.constant 2 : index
// CHECK: krnl.store [[VAR_23_]], [[RES_]]{{.}}[[VAR_c2_]]{{.}} : memref<3xi64>
// CHECK: [[VAR_dim_1_:%.+]] = memref.dim [[VAR_16_]], [[CST_0_]] : memref<?x4x5xf32>
// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[VAR_dim_1_]] : index to i64
// CHECK: krnl.store [[VAR_17_]], [[RES_]]{{.}}[[CST_0_]]{{.}} : memref<3xi64>
// CHECK: krnl.store [[CST_4_]], [[RES_]]{{.}}[[CST_1_]]{{.}} : memref<3xi64>
// CHECK: krnl.store [[CST_5_]], [[RES_]]{{.}}[[CST_2_]]{{.}} : memref<3xi64>
// CHECK: memref.dealloc [[VAR_2_]] : memref<?xmemref<?x4x5xf32>>
// CHECK: memref.dealloc [[VAR_16_]] : memref<?x4x5xf32>
// CHECK: return [[RES_]] : memref<3xi64>
// CHECK: }
}

0 comments on commit 6ba10ca

Please sign in to comment.