From 7220d1e13f205f82235a662ee61c6be05ca80940 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Fri, 5 Feb 2021 00:39:16 +0900 Subject: [PATCH] Use local buffer for reduction loops in MatMul, Gemm, Conv, and Pooling ops (#504) * Use local alloca for Matmul, Gemm, Conv, and Pooling Signed-off-by: Tung D. Le * Edit Matmul Signed-off-by: Tung D. Le * Edit lit test Signed-off-by: Tung D. Le * Explicitly evaluate indexexpr Signed-off-by: Tung D. Le * Comments Signed-off-by: Tung D. Le * Emit a SSA value for kernel Signed-off-by: Tung D. Le * Edit lit tests Signed-off-by: Tung D. Le --- src/Conversion/ONNXToKrnl/Math/Gemm.cpp | 60 ++++---- src/Conversion/ONNXToKrnl/Math/MatMul.cpp | 16 ++- src/Conversion/ONNXToKrnl/NN/Conv.cpp | 43 +++--- src/Conversion/ONNXToKrnl/NN/Pooling.cpp | 18 ++- test/mlir/onnx/onnx_enable_memory_pool.mlir | 8 +- test/mlir/onnx/onnx_lowering.mlir | 135 ++++++++++++------ .../onnx/onnx_lowering_with_canonicalize.mlir | 36 ++--- 7 files changed, 203 insertions(+), 113 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index cd1202f17c29..e5814a526837 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -60,7 +60,10 @@ struct ONNXGemmOpLowering : public ConversionPattern { SmallVector resAccessFct({n, m}); // Insert res[n,m] = 0. - outerContext.createKrnlStoreOp(zero, alloc, resAccessFct); + // Create a local reduction value for res[n,m]. + Value reductionVal = + rewriter.create(loc, MemRefType::get({}, elementType)); + rewriter.create(loc, zero, reductionVal, ArrayRef{}); // Create the inner reduction loop. BuildKrnlLoop innerLoops(rewriter, loc, 1); @@ -68,6 +71,37 @@ struct ONNXGemmOpLowering : public ConversionPattern { innerLoops.pushBounds(0, shapeHelper.aDims[1]); innerLoops.createIterateOp(); + // Now start writing code inside the inner loop: get A & B access functions. 
+ auto ipOuterLoopRegion = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); + { + IndexExpr k = + outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0)); + SmallVector aAccessFct, bAccessFct; + if (gemmOp.transA() != 0) + aAccessFct = {k, n}; + else + aAccessFct = {n, k}; + if (gemmOp.transB() != 0) + bAccessFct = {m, k}; + else + bAccessFct = {k, m}; + // Add mat mul operation. + Value loadedA = + outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct); + Value loadedB = + outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct); + Value loadedY = + rewriter.create(loc, reductionVal, ArrayRef{}); + Value AB = rewriter.create(loc, loadedA, loadedB); + Value accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create( + loc, accumulated, reductionVal, ArrayRef{}); + } + rewriter.restoreInsertionPoint(ipOuterLoopRegion); + Value loadedAB = + rewriter.create(loc, reductionVal, ArrayRef{}); + // Write code after the completion of the inner loop. // Compute the c access function using the broadcast rules. SmallVector cAccessFct; @@ -81,7 +115,6 @@ struct ONNXGemmOpLowering : public ConversionPattern { } // Calculate reduction(AB)*alpha. - Value loadedAB = outerContext.createKrnlLoadOp(alloc, resAccessFct); Value alphaAB = rewriter.create(loc, alpha, loadedAB); if (shapeHelper.hasBias) { // Res = AB*alpha + beta * C. @@ -95,29 +128,6 @@ struct ONNXGemmOpLowering : public ConversionPattern { outerContext.createKrnlStoreOp(alphaAB, alloc, resAccessFct); } - // Now start writing code inside the inner loop: get A & B access functions. 
- rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); - IndexExpr k = - outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0)); - SmallVector aAccessFct, bAccessFct; - if (gemmOp.transA() != 0) - aAccessFct = {k, n}; - else - aAccessFct = {n, k}; - if (gemmOp.transB() != 0) - bAccessFct = {m, k}; - else - bAccessFct = {k, m}; - // Add mat mul operation. - Value loadedA = - outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct); - Value loadedB = - outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct); - Value loadedY = outerContext.createKrnlLoadOp(alloc, resAccessFct); - Value AB = rewriter.create(loc, loadedA, loadedB); - Value accumulated = rewriter.create(loc, loadedY, AB); - outerContext.createKrnlStoreOp(accumulated, alloc, resAccessFct); - rewriter.replaceOp(op, alloc); return success(); diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index c264861e4610..42d936cd004e 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -52,7 +52,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern { outerContext.createLoopInductionIndicesFromArrayValues( outputLoops.getAllInductionVar(), resAccessFct); // Insert res[...] = 0. - outerContext.createKrnlStoreOp(zero, alloc, resAccessFct); + // Create a local reduction value for res[...]. + Value reductionVal = + rewriter.create(loc, MemRefType::get({}, elementType)); + rewriter.create(loc, zero, reductionVal, ArrayRef{}); // Create the inner reduction loop; trip count is last dim of A. BuildKrnlLoop innerLoops(rewriter, loc, 1); @@ -63,7 +66,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern { innerLoops.createIterateOp(); // Now start writing code inside the inner loop: get A & B access functions. 
+ auto ipOuterLoopRegion = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); + IndexExpr k = outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0)); SmallVector aAccessFct, bAccessFct; @@ -100,9 +105,16 @@ struct ONNXMatMulOpLowering : public ConversionPattern { outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct); Value loadedB = outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct); - Value loadedY = outerContext.createKrnlLoadOp(alloc, resAccessFct); + Value loadedY = + rewriter.create(loc, reductionVal, ArrayRef{}); Value AB = rewriter.create(loc, loadedA, loadedB); Value accumulated = rewriter.create(loc, loadedY, AB); + rewriter.create( + loc, accumulated, reductionVal, ArrayRef{}); + + rewriter.restoreInsertionPoint(ipOuterLoopRegion); + accumulated = + rewriter.create(loc, reductionVal, ArrayRef{}); outerContext.createKrnlStoreOp(accumulated, alloc, resAccessFct); // Done. diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index 78e4315cc109..0677f16d253d 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -187,6 +187,8 @@ struct ONNXConvOpLowering : public ConversionPattern { outerLoops.getInductionVar(gIndex)); kernel = g * kernelsPerGroupValue + kernel; } + // Evaluate kernel to emit its SSA value at this location. + kernel.getValue(); // 2.2 Define spatial loops int64_t nSpatialLoops = resultShape.size() - spatialStartIndex; @@ -199,7 +201,6 @@ struct ONNXConvOpLowering : public ConversionPattern { // for rX = 0 .. RX spatialLoops.createIterateOp(); rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock()); - { // 3. Emit the body of the spatial loop nest. 
// 3.1 Emit: R[n][kernel][r1][r2] = 0; @@ -212,9 +213,16 @@ struct ONNXConvOpLowering : public ConversionPattern { // rX for (auto arg : spatialLoops.getIterateBlock()->getArguments()) resultIndices.emplace_back(ieContext.createLoopInductionIndex(arg)); - // Store initializer value into output location. + + // Initialize the output. ieContext.createKrnlStoreOp(zero, alloc, resultIndices); + // Create a local reduction value. + Value reductionVal = rewriter.create( + loc, MemRefType::get({}, memRefType.getElementType())); + rewriter.create( + loc, zero, reductionVal, ArrayRef{}); + // Prepare induction variables. SmallVector, 4> IVExprs; { @@ -273,19 +281,8 @@ struct ONNXConvOpLowering : public ConversionPattern { // 3.4 Emit inner loop nest. innerLoops.createIterateOp(); - // Emit the bias, if needed. - if (hasBias) { - auto loadResult = ieContext.createKrnlLoadOp(alloc, resultIndices); - SmallVector biasIndices; - biasIndices.emplace_back(kernel); - auto loadBias = ieContext.createKrnlLoadOp(biasOperand, biasIndices); - auto resultWithBias = - rewriter.create(loc, loadResult, loadBias); - // Store initializer value into output location. - ieContext.createKrnlStoreOp(resultWithBias, alloc, resultIndices); - } - // + auto ipOuterLoopRegion = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); { // 4. Emit inner loop body @@ -350,12 +347,26 @@ struct ONNXConvOpLowering : public ConversionPattern { auto loadKernel = ieContext.createKrnlLoadOp(kernelOperand, kernelIndices); auto loadPartialSum = - ieContext.createKrnlLoadOp(alloc, resultIndices); + rewriter.create(loc, reductionVal, ArrayRef{}); Value result = rewriter.create(loc, loadPartialSum, rewriter.create(loc, loadData, loadKernel)); // 4.4 Store computed value into output location. 
- ieContext.createKrnlStoreOp(result, alloc, resultIndices); + rewriter.create( + loc, result, reductionVal, ArrayRef{}); } + rewriter.restoreInsertionPoint(ipOuterLoopRegion); + + auto result = + rewriter.create(loc, reductionVal, ArrayRef{}); + // Store the result. Optionally add bias. + if (hasBias) { + SmallVector biasIndices; + biasIndices.emplace_back(kernel); + auto loadBias = ieContext.createKrnlLoadOp(biasOperand, biasIndices); + auto resultWithBias = rewriter.create(loc, result, loadBias); + ieContext.createKrnlStoreOp(resultWithBias, alloc, resultIndices); + } else + ieContext.createKrnlStoreOp(result, alloc, resultIndices); } } rewriter.replaceOp(op, alloc); diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index c6caf5667381..32e58a47db05 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -345,7 +345,11 @@ struct ONNXPoolOpLowering : public ConversionPattern { ieContext.createLoopInductionIndex(outputLoops.getInductionVar(i))); // 2.1 Emit: output[n][c][ho][wo] = identity - ieContext.createKrnlStoreOp(identity, alloc, outputIndices); + // Create a local reduction value for output[n][c][ho][wo]. + Value reductionVal = rewriter.create( + loc, MemRefType::get({}, memRefType.getElementType())); + rewriter.create( + loc, identity, reductionVal, ArrayRef{}); // 2.2 Emit affine maps which express the lower and upper bounds for the // pooling window's dimensions. @@ -441,7 +445,7 @@ struct ONNXPoolOpLowering : public ConversionPattern { // Create a krnl iterate. poolingLoops.createIterateOp(); - auto ipOuterLoops = rewriter.saveInsertionPoint(); + auto ipOuterLoopRegion = rewriter.saveInsertionPoint(); rewriter.setInsertionPointToStart(poolingLoops.getIterateBlock()); { // 2.4 Emit the body of the pooling loop nest. 
@@ -472,14 +476,18 @@ struct ONNXPoolOpLowering : public ConversionPattern { Value loadInput = ieContext.createKrnlLoadOp(inputOperand, inputIndices); Value loadPartialOutput = - ieContext.createKrnlLoadOp(alloc, outputIndices); + rewriter.create(loc, reductionVal, ArrayRef{}); Value output = emitScalarOpFor(rewriter, loc, op, outputElementType, {loadPartialOutput, loadInput}); - ieContext.createKrnlStoreOp(output, alloc, outputIndices); + rewriter.create( + loc, output, reductionVal, ArrayRef{}); } + rewriter.restoreInsertionPoint(ipOuterLoopRegion); + Value output = + rewriter.create(loc, reductionVal, ArrayRef{}); + ieContext.createKrnlStoreOp(output, alloc, outputIndices); // 2.5 Post-processing for the pooling window, e.g. taking average. - rewriter.restoreInsertionPoint(ipOuterLoops); SmallVector outputIndicesInValue; for (IndexExpr expr : outputIndices) outputIndicesInValue.emplace_back(expr.getValue()); diff --git a/test/mlir/onnx/onnx_enable_memory_pool.mlir b/test/mlir/onnx/onnx_enable_memory_pool.mlir index 3eecd333f576..f67b5eae187b 100644 --- a/test/mlir/onnx/onnx_enable_memory_pool.mlir +++ b/test/mlir/onnx/onnx_enable_memory_pool.mlir @@ -48,12 +48,15 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3 // CHECK: krnl.store [[ADDF1]], [[GETREF1]][%arg2, %arg3] : memref<10x10xf32> // CHECK: krnl.define_loops // CHECK: krnl.iterate + // CHECK: [[REDUCTION_SUM:%.+]] = alloca() : memref // CHECK: [[LOAD3:%.+]] = krnl.load [[GETREF1]][%arg2, %arg4] : memref<10x10xf32> // CHECK: [[LOAD4:%.+]] = krnl.load %arg1[%arg4, %arg3] : memref<10x20xf32> - // CHECK: [[LOAD5:%.+]] = krnl.load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32> + // CHECK: [[LOAD5:%.+]] = krnl.load [[REDUCTION_SUM]][] : memref // CHECK: [[MULF1:%.+]] = mulf [[LOAD3]], [[LOAD4]] : f32 // CHECK: [[ADDF2:%.+]] = addf [[LOAD5]], [[MULF1]] : f32 - // CHECK: krnl.store [[ADDF2]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32> + // CHECK: krnl.store [[ADDF2]], 
[[REDUCTION_SUM]][] : memref + // CHECK: [[SUM:%.+]] = krnl.load [[REDUCTION_SUM]][] : memref + // CHECK: krnl.store [[SUM]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32> // CHECK: krnl.define_loops // CHECK: krnl.iterate // CHECK: [[LOAD6:%.+]] = krnl.load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32> @@ -100,7 +103,6 @@ func @test_enable_memory_pool_3(%arg0: tensor, %arg1: tensor, // CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref // CHECK: krnl.define_loops 2 // CHECK: krnl.iterate - // CHECK: krnl.store [[CST]], [[DATA3]][%arg3, %arg4] : memref // CHECK: krnl.define_loops 1 // CHECK: krnl.iterate // CHECK: krnl.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 21345e6bc5cb..8035d3ba56b8 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1065,16 +1065,19 @@ func private @test_matmul1(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) - //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 //CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_0_]], [[I_2_]]{{.}} : memref<10x5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_2_]], [[I_1_]]{{.}} : memref<5x10xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_6_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 
//CHECK: [[VAR_7_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_6_]] : f32 -//CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +//CHECK: krnl.store [[VAR_7_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store [[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> //CHECK: } //CHECK: return [[RES_]] : memref<10x10xf32> //CHECK: } @@ -1093,16 +1096,19 @@ func private @test_matmul2(%arg0 : tensor<10x5xf32>, %arg1 : tensor<2x3x5x10xf32 //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 //CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 10, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 10) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_4_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_2_]], [[I_4_]]{{.}} : memref<10x5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_0_]], [[I_1_]], [[I_4_]], [[I_3_]]{{.}} : memref<2x3x5x10xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_6_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_7_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_6_]] : f32 -//CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: krnl.store [[VAR_7_]], 
[[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store [[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> //CHECK: } //CHECK: return [[RES_]] : memref<2x3x10x10xf32> //CHECK: } @@ -1121,16 +1127,19 @@ func private @test_matmul3(%arg0 : tensor<2x3x10x5xf32>, %arg1 : tensor<2x3x5x10 //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 //CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 3, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 10, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 10) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_4_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_4_]]{{.}} : memref<2x3x10x5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_0_]], [[I_1_]], [[I_4_]], [[I_3_]]{{.}} : memref<2x3x5x10xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_6_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_7_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_6_]] : f32 -//CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> +//CHECK: krnl.store [[VAR_7_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store 
[[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]], [[I_3_]]{{.}} : memref<2x3x10x10xf32> //CHECK: } //CHECK: return [[RES_]] : memref<2x3x10x10xf32> //CHECK: } @@ -1149,16 +1158,19 @@ func private @test_matmul4(%arg0 : tensor<5xf32>, %arg1 : tensor<5x10xf32>) -> t //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 10) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<10xf32> +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_1_]]{{.}} : memref<5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_1_]], [[I_0_]]{{.}} : memref<5x10xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]]{{.}} : memref<10xf32> +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_6_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_7_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_6_]] : f32 -//CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<10xf32> +//CHECK: krnl.store [[VAR_7_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store [[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<10xf32> //CHECK: } //CHECK: return [[RES_]] : memref<10xf32> //CHECK: } @@ -1179,16 +1191,19 @@ func private @test_matmul5(%arg0 : tensor<5xf32>, %arg1 : tensor) -> //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 //CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -//CHECK: 
krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_2_]]{{.}} : memref<5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_0_]], [[I_2_]], [[I_1_]]{{.}} : memref -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_7_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_8_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_7_]] : f32 -//CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: krnl.store [[VAR_8_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store [[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref //CHECK: } //CHECK: return [[RES_]] : memref //CHECK: } @@ -1209,16 +1224,19 @@ func private @test_matmul6(%arg0 : tensor, %arg1 : tensor<5xf32>) -> //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 //CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_0_]], [[I_1_]], [[I_2_]]{{.}} : memref //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_2_]]{{.}} : 
memref<5xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_7_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_8_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_7_]] : f32 -//CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref +//CHECK: krnl.store [[VAR_8_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store [[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref //CHECK: } //CHECK: return [[RES_]] : memref //CHECK: } @@ -1237,16 +1255,19 @@ func private @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tens //CHECK: [[VAR_cst_:%.+]] = constant 0.000000e+00 : f32 //CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 1) { -//CHECK: krnl.store [[VAR_cst_]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<1xf32> +//CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref +//CHECK: krnl.store [[VAR_cst_]], [[REDUCTION_VAL]][] : memref //CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 //CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 5) { //CHECK: [[LOAD_A_MEM_:%.+]] = krnl.load [[A_]]{{.}}[[I_1_]]{{.}} : memref<5xf32> //CHECK: [[LOAD_B_MEM_:%.+]] = krnl.load [[B_]]{{.}}[[I_1_]]{{.}} : memref<5xf32> -//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]]{{.}} : memref<1xf32> +//CHECK: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref //CHECK: [[VAR_6_:%.+]] = mulf [[LOAD_A_MEM_]], [[LOAD_B_MEM_]] : f32 //CHECK: [[VAR_7_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_6_]] : f32 -//CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<1xf32> +//CHECK: krnl.store [[VAR_7_]], [[REDUCTION_VAL]][] : memref //CHECK: } +//CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref +//CHECK: krnl.store 
[[LOAD_REDUCTION]], [[RES_]]{{.}}[[I_0_]]{{.}} : memref<1xf32> //CHECK: } //CHECK: return [[RES_]] : memref<1xf32> //CHECK: } @@ -1274,7 +1295,8 @@ func private @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : te // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 58) { - // CHECK: krnl.store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref + // CHECK: krnl.store [[CONST1]], [[REDUCTION_VAL]][] : memref // CHECK: [[START1:%.+]] = affine.max #[[ZERO_MAP2]](%arg4, %arg4) // CHECK: {{.*}} = affine.min {{.*}} // CHECK: [[KERNEL_OFFSET1:%.+]] = affine.min #[[ZERO_MAP2]](%arg4, %arg4) @@ -1291,11 +1313,13 @@ func private @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : te // CHECK: [[K2:%.+]] = subi %arg8, [[KERNEL_OFFSET2]] : index // CHECK: [[DATA:%.+]] = krnl.load %arg0[%arg2, %arg6, [[R1]], [[R2]]{{\]}} : memref<1x2x32x64xf32> // CHECK: [[KERNEL:%.+]] = krnl.load %arg1[%arg3, %arg6, [[K1]], [[K2]]{{\]}} : memref<5x2x6x7xf32> - // CHECK: [[ACC_RES:%.+]] = krnl.load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32> + // CHECK: [[ACC_RES:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32 // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32 - // CHECK: krnl.store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32> + // CHECK: krnl.store [[ADD]], [[REDUCTION_VAL]][] : memref // CHECK: } + // CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref + // CHECK: krnl.store [[LOAD_REDUCTION]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32> // CHECK: } // CHECK: } @@ -1317,12 +1341,13 @@ func private @test_conv_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tenso // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 
2 // CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) { - // CHECK: krnl.store [[CONST1]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref + // CHECK: krnl.store [[CONST1]], [[REDUCTION_VAL]][] : memref // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3 // CHECK: krnl.iterate([[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg8 = 0 to min #{{.*}}(%arg5)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}], [[INNER_LOOPS]]#2 -> %arg9 = 0 to min #{{.*}}(%arg6)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}]) { // CHECK: } - // CHECK: [[BIAS1:%.+]] = krnl.load [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> + // CHECK: [[BIAS1:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref // CHECK: [[BIAS2:%.+]] = krnl.load %arg2[%arg4] : memref<5xf32> // CHECK: [[BIAS3:%.+]] = addf [[BIAS1]], [[BIAS2]] : f32 // CHECK: krnl.store [[BIAS3]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> @@ -1378,7 +1403,8 @@ func private @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 14, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 29) { - // CHECK: krnl.store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref + // CHECK: krnl.store [[CONST1]], [[REDUCTION_VAL]][] : memref // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3 // CHECK: krnl.iterate([[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg6 = 0 to 9, [[INNER_LOOPS]]#1 -> %arg7 = 0 to min #[[BOUND]](%arg4)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}], [[INNER_LOOPS]]#2 -> %arg8 = 0 to min #[[BOUND]](%arg5)[{{.*}}, {{.*}}, 
{{.*}}, {{.*}}, {{.*}}]) { @@ -1388,11 +1414,13 @@ func private @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, // CHECK: [[K2:%.+]] = subi %arg8, [[KERNEL_OFFSET2]] : index // CHECK: [[DATA:%.+]] = krnl.load %arg0[%arg2, %arg6, [[R1]], [[R2]]{{\]}} : memref<1x9x32x64xf32> // CHECK: [[KERNEL:%.+]] = krnl.load %arg1[%arg3, %arg6, [[K1]], [[K2]]{{\]}} : memref<5x9x6x7xf32> - // CHECK: [[ACC_RES:%.+]] = krnl.load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: [[ACC_RES:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32 // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32 - // CHECK: krnl.store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> + // CHECK: krnl.store [[ADD]], [[REDUCTION_VAL]][] : memref // CHECK: } + // CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref + // CHECK: krnl.store [[LOAD_REDUCTION]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x14x29xf32> // CHECK: } // CHECK: } @@ -1422,11 +1450,13 @@ func private @test_conv_bias_group_pad_stride_dilation(%arg0 : tensor<1x9x32x64x // CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3 // CHECK: krnl.iterate([[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg3 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg4 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg5 = 0 to 1) { + // CHECK: [[MAP0_APPLY:%.+]] = affine.apply #[[MAP0]](%arg4, %arg5) // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg6 = 0 to 13, [[SPATIAL_LOOPS]]#1 -> %arg7 = 0 to 28) { - // CHECK: [[MAP0_APPLY:%.+]] = affine.apply #[[MAP0]](%arg4, %arg5) // CHECK: krnl.store [[INITIALIZE_VALUE]], [[RES]][%arg3, [[MAP0_APPLY]], %arg6, %arg7] : memref<1x5x13x28xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref + // CHECK: krnl.store [[INITIALIZE_VALUE]], [[REDUCTION_VAL]][] : memref // CHECK: 
[[START1:%.+]] = affine.max #[[MAP1]](%arg4, %arg5, %arg6, %arg6) // CHECK: {{.*}} = affine.min {{.*}} // CHECK: [[OFFSET1:%.+]] = affine.min #[[MAP1]](%arg4, %arg5, %arg6, %arg6) @@ -1456,14 +1486,14 @@ func private @test_conv_bias_group_pad_stride_dilation(%arg0 : tensor<1x9x32x64x // CHECK: [[DATA:%.+]] = krnl.load %arg0[%arg3, [[GROUP]], [[R1]], [[R2]]] : memref<1x9x32x64xf32> // CHECK: [[KERNEL:%.+]] = krnl.load %arg1{{\[}}[[MAP0_APPLY]], %arg8, [[K1]], [[K2]]{{\]}} : memref<5x3x6x7xf32> - // CHECK: [[ACC_RES:%.+]] = krnl.load [[RES]][%arg3, [[MAP0_APPLY]], %arg6, %arg7] : memref<1x5x13x28xf32> + // CHECK: [[ACC_RES:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32 // CHECK: [[ADD1:%.+]] = addf [[ACC_RES]], [[MUL]] : f32 - // CHECK: krnl.store [[ADD1]], [[RES]][%arg3, [[MAP0_APPLY]], %arg6, %arg7] : memref<1x5x13x28xf32> + // CHECK: krnl.store [[ADD1]], [[REDUCTION_VAL]][] : memref // CHECK: } - // CHECK: [[BIAS1:%.+]] = krnl.load [[RES]][%arg3, [[MAP0_APPLY]], %arg6, %arg7] : memref<1x5x13x28xf32> + // CHECK: [[BIAS1:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref // CHECK: [[BIAS2:%.+]] = krnl.load %arg2{{\[}}[[MAP0_APPLY]]{{\]}} : memref<5xf32> - // CHECK: [[BIAS3:%.+]] = addf %11, %12 : f32 + // CHECK: [[BIAS3:%.+]] = addf [[BIAS1]], [[BIAS2]] : f32 // CHECK: krnl.store [[BIAS3]], [[RES]][%arg3, [[MAP0_APPLY]], %arg6, %arg7] : memref<1x5x13x28xf32> // CHECK: } // CHECK: } @@ -1701,17 +1731,18 @@ func private @test_pool_general_computation(%arg0 : tensor<1x3x32x32xf32>) -> te // CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 // CHECK: krnl.iterate([[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { - // CHECK: krnl.store [[IDENTITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: 
[[REDUCTION_VAL:%.+]] = alloca() : memref<f32> + // CHECK: krnl.store [[IDENTITY]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #[[BOUND]](%arg3)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #[[BOUND]](%arg4)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}]) { // CHECK: {{.*}} = krnl.load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> - // CHECK: {{.*}} = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: krnl.store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: {{.*}} = krnl.load [[REDUCTION_VAL]][] : memref<f32> + // CHECK: krnl.store {{.*}}, [[REDUCTION_VAL]][] : memref<f32> // CHECK: } - // CHECK: {{.*}} = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> - // CHECK: krnl.store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> + // CHECK: krnl.store [[LOAD_REDUCTION]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK: } } @@ -1742,7 +1773,8 @@ func private @test_averagepool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> // CHECK-LABEL: @test_averagepool_identity_value // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> // CHECK: [[IDENTITY:%.+]] = constant 0.000000e+00 : f32 - // CHECK: krnl.store [[IDENTITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> + // CHECK: krnl.store [[IDENTITY]], [[REDUCTION_VAL]][] : memref<f32> } // ----- @@ -1754,7 +1786,8 @@ func private @test_maxpool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tens // CHECK-LABEL: @test_maxpool_identity_value // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32> // CHECK: [[IDENTITY:%.+]] = constant 0xFF800000 : f32 - // CHECK: krnl.store [[IDENTITY]], 
[[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> + // CHECK: krnl.store [[IDENTITY]], [[REDUCTION_VAL]][] : memref<f32> } // ----- @@ -1769,14 +1802,19 @@ func private @test_averagepool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) // CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 // CHECK: krnl.iterate([[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> + // CHECK: krnl.store {{.*}}, [[REDUCTION_VAL]][] : memref<f32> + // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #{{.*}}(%arg3)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #{{.*}}(%arg4)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}]) { // CHECK: [[INPUT_LOAD:%.+]] = krnl.load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> - // CHECK: [[OUTPUT_LOAD:%.+]] = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[OUTPUT_LOAD:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[SUM:%.+]] = addf [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 - // CHECK: krnl.store [[SUM]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: krnl.store [[SUM]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } + // CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> + // CHECK: krnl.store [[LOAD_REDUCTION]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK: [[NUMERATOR:%.+]] = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK: [[AVERAGE:%.+]] = divf [[NUMERATOR]], {{.*}} : f32 @@ -1796,15 +1834,20 @@ func private @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> t // 
CHECK: [[OUTPUT_LOOPS:%.+]]:4 = krnl.define_loops 4 // CHECK: krnl.iterate([[OUTPUT_LOOPS]]#0, [[OUTPUT_LOOPS]]#1, [[OUTPUT_LOOPS]]#2, [[OUTPUT_LOOPS]]#3) with ([[OUTPUT_LOOPS]]#0 -> %arg1 = 0 to 1, [[OUTPUT_LOOPS]]#1 -> %arg2 = 0 to 3, [[OUTPUT_LOOPS]]#2 -> %arg3 = 0 to 31, [[OUTPUT_LOOPS]]#3 -> %arg4 = 0 to 31) { + // CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> + // CHECK: krnl.store {{.*}}, [[REDUCTION_VAL]][] : memref<f32> + // CHECK: [[POOL_LOOPS:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[POOL_LOOPS]]#0, [[POOL_LOOPS]]#1) with ([[POOL_LOOPS]]#0 -> %arg5 = 0 to min #{{.*}}(%arg3)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}], [[POOL_LOOPS]]#1 -> %arg6 = 0 to min #{{.*}}(%arg4)[{{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}]) { // CHECK: [[INPUT_LOAD:%.+]] = krnl.load %arg0[%arg1, %arg2, {{.*}}, {{.*}}] : memref<1x3x32x32xf32> - // CHECK: [[OUTPUT_LOAD:%.+]] = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: [[OUTPUT_LOAD:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[GREATER:%.+]] = cmpf "ogt", [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 // CHECK: [[SELECT:%.+]] = select [[GREATER]], [[OUTPUT_LOAD]], [[INPUT_LOAD]] : f32 - // CHECK: krnl.store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> + // CHECK: krnl.store [[SELECT]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } + // CHECK: [[LOAD_REDUCTION:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> + // CHECK: krnl.store [[LOAD_REDUCTION]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK-NOT: {{.*}} = krnl.load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> // CHECK-NOT: krnl.store {{.*}}, [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32> @@ -3423,4 +3466,4 @@ func private @test_loop_simple_main_graph(%arg0: tensor, %arg1: tensor, // CHECK: return [[Y]] : memref<1xi64> // CHECK: } // CHECK: } -} \ No newline at end of file +} diff --git a/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir 
b/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir index ef8a4657337f..4b605580e749 100644 --- a/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir +++ b/test/mlir/onnx/onnx_lowering_with_canonicalize.mlir @@ -352,17 +352,18 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso // CHECK-DAG: [[RES_:%.+]] = alloc() : memref<10x10xf32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> +// CHECK: krnl.store [[CST_0_dot_000000_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 5) { // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[I_2_]], [[I_0_]]{{.}} : memref<5x10xf32> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[I_2_]], [[I_1_]]{{.}} : memref<5x10xf32> -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_11_:%.+]] = mulf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: [[VAR_12_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_11_]] : f32 -// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: krnl.store [[VAR_12_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } -// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK-DAG: [[VAR_4_:%.+]] = mulf [[CST_1_dot_000000_]], [[LOAD_RES_MEM_1_]] : f32 // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load 
[[PARAM_2_]]{{.}}[[I_1_]]{{.}} : memref<10xf32> // CHECK: [[VAR_6_:%.+]] = mulf [[CST_5_dot_000000_]], [[LOAD_PARAM_2_MEM_]] : f32 @@ -396,19 +397,20 @@ func @test_gemm_all_dyn(%arg0 : tensor, %arg1 : tensor, %arg2: // CHECK-DAG: [[RES_:%.+]] = alloc([[DIM_0_]], [[DIM_2_]]) : memref<?x?xf32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[DIM_0_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[DIM_2_]]) { -// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<?x?xf32> +// CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> +// CHECK: krnl.store [[CST_0_dot_000000_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to [[DIM_1_]]) { // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[I_2_]], [[I_0_]]{{.}} : memref<?x?xf32> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[I_2_]], [[I_1_]]{{.}} : memref<?x?xf32> -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<?x?xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_17_:%.+]] = mulf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: [[VAR_18_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_17_]] : f32 -// CHECK: krnl.store [[VAR_18_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<?x?xf32> +// CHECK: krnl.store [[VAR_18_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } +// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_7_:%.+]] = cmpi "sgt", [[DIM_3_]], [[CST_1_]] : index // CHECK-DAG: [[VAR_8_:%.+]] = select [[VAR_7_]], [[I_1_]], [[CST_0_]] : index -// CHECK-DAG: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<?x?xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_10_:%.+]] = mulf [[CST_1_dot_000000_]], [[LOAD_RES_MEM_1_]] : f32 // 
CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_8_]]{{.}} : memref<?xf32> @@ -438,17 +440,18 @@ func @test_gemm_k_dyn(%arg0 : tensor, %arg1 : tensor, %arg2: // CHECK-DAG: [[DIM_0_:%.+]] = dim [[PARAM_0_]], [[CST_0_]] : memref<?x10xf32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> +// CHECK: krnl.store [[CST_0_dot_000000_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to [[DIM_0_]]) { // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[I_2_]], [[I_0_]]{{.}} : memref<?x10xf32> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[I_2_]], [[I_1_]]{{.}} : memref<?x10xf32> -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_12_:%.+]] = mulf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: [[VAR_13_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_12_]] : f32 -// CHECK: krnl.store [[VAR_13_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: krnl.store [[VAR_13_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } -// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK-DAG: [[VAR_5_:%.+]] = mulf [[CST_1_dot_000000_]], [[LOAD_RES_MEM_1_]] : f32 // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[I_1_]]{{.}} : memref<10xf32> // CHECK: [[VAR_7_:%.+]] = mulf [[CST_5_dot_000000_]], [[LOAD_PARAM_2_MEM_]] : f32 @@ -478,19 +481,20 @@ func 
@test_gemm_c_dyn(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: // CHECK-DAG: [[DIM_0_:%.+]] = dim [[PARAM_2_]], [[CST_0_]] : memref<?xf32> // CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10) { -// CHECK: krnl.store [[CST_0_dot_000000_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: [[REDUCTION_VAL:%.+]] = alloca() : memref<f32> +// CHECK: krnl.store [[CST_0_dot_000000_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = 0 to 5) { // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[I_2_]], [[I_0_]]{{.}} : memref<5x10xf32> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[I_2_]], [[I_1_]]{{.}} : memref<5x10xf32> -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_14_:%.+]] = mulf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: [[VAR_15_:%.+]] = addf [[LOAD_RES_MEM_]], [[VAR_14_]] : f32 -// CHECK: krnl.store [[VAR_15_]], [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> +// CHECK: krnl.store [[VAR_15_]], [[REDUCTION_VAL]][] : memref<f32> // CHECK: } +// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[REDUCTION_VAL]][] : memref<f32> // CHECK: [[VAR_4_:%.+]] = cmpi "sgt", [[DIM_0_]], [[CST_1_]] : index // CHECK-DAG: [[VAR_5_:%.+]] = select [[VAR_4_]], [[I_1_]], [[CST_0_]] : index -// CHECK-DAG: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : memref<10x10xf32> // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_7_:%.+]] = mulf [[CST_1_dot_000000_]], [[LOAD_RES_MEM_1_]] : f32 // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>