Skip to content

Commit

Permalink
Use local buffer for reduction loops in MatMul, Gemm, Conv, and Pooling ops (llvm#504)
Browse files Browse the repository at this point in the history

* Use local alloca for MatMul, Gemm, Conv, and Pooling

Signed-off-by: Tung D. Le <[email protected]>

* Edit MatMul

Signed-off-by: Tung D. Le <[email protected]>

* Edit lit test

Signed-off-by: Tung D. Le <[email protected]>

* Explicitly evaluate indexexpr

Signed-off-by: Tung D. Le <[email protected]>

* Comments

Signed-off-by: Tung D. Le <[email protected]>

* Emit a SSA value for kernel

Signed-off-by: Tung D. Le <[email protected]>

* Edit lit tests

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Feb 4, 2021
1 parent 3c1dee4 commit 7220d1e
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 113 deletions.
60 changes: 35 additions & 25 deletions src/Conversion/ONNXToKrnl/Math/Gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,48 @@ struct ONNXGemmOpLowering : public ConversionPattern {
SmallVector<IndexExpr, 4> resAccessFct({n, m});

// Insert res[n,m] = 0.
outerContext.createKrnlStoreOp(zero, alloc, resAccessFct);
// Create a local reduction value for res[n,m].
Value reductionVal =
rewriter.create<AllocaOp>(loc, MemRefType::get({}, elementType));
rewriter.create<KrnlStoreOp>(loc, zero, reductionVal, ArrayRef<Value>{});

// Create the inner reduction loop.
BuildKrnlLoop innerLoops(rewriter, loc, 1);
innerLoops.createDefineOp();
innerLoops.pushBounds(0, shapeHelper.aDims[1]);
innerLoops.createIterateOp();

// Now start writing code inside the inner loop: get A & B access functions.
auto ipOuterLoopRegion = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{
IndexExpr k =
outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0));
SmallVector<IndexExpr, 4> aAccessFct, bAccessFct;
if (gemmOp.transA() != 0)
aAccessFct = {k, n};
else
aAccessFct = {n, k};
if (gemmOp.transB() != 0)
bAccessFct = {m, k};
else
bAccessFct = {k, m};
// Add mat mul operation.
Value loadedA =
outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct);
Value loadedB =
outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct);
Value loadedY =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
Value AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
Value accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
rewriter.create<KrnlStoreOp>(
loc, accumulated, reductionVal, ArrayRef<Value>{});
}
rewriter.restoreInsertionPoint(ipOuterLoopRegion);
Value loadedAB =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});

// Write code after the completion of the inner loop.
// Compute the c access function using the broadcast rules.
SmallVector<IndexExpr, 4> cAccessFct;
Expand All @@ -81,7 +115,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
}

// Calculate reduction(AB)*alpha.
Value loadedAB = outerContext.createKrnlLoadOp(alloc, resAccessFct);
Value alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
if (shapeHelper.hasBias) {
// Res = AB*alpha + beta * C.
Expand All @@ -95,29 +128,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
outerContext.createKrnlStoreOp(alphaAB, alloc, resAccessFct);
}

// Now start writing code inside the inner loop: get A & B access functions.
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
IndexExpr k =
outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0));
SmallVector<IndexExpr, 4> aAccessFct, bAccessFct;
if (gemmOp.transA() != 0)
aAccessFct = {k, n};
else
aAccessFct = {n, k};
if (gemmOp.transB() != 0)
bAccessFct = {m, k};
else
bAccessFct = {k, m};
// Add mat mul operation.
Value loadedA =
outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct);
Value loadedB =
outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct);
Value loadedY = outerContext.createKrnlLoadOp(alloc, resAccessFct);
Value AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
Value accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
outerContext.createKrnlStoreOp(accumulated, alloc, resAccessFct);

rewriter.replaceOp(op, alloc);

return success();
Expand Down
16 changes: 14 additions & 2 deletions src/Conversion/ONNXToKrnl/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
outerContext.createLoopInductionIndicesFromArrayValues(
outputLoops.getAllInductionVar(), resAccessFct);
// Insert res[...] = 0.
outerContext.createKrnlStoreOp(zero, alloc, resAccessFct);
// Create a local reduction value for res[...].
Value reductionVal =
rewriter.create<AllocaOp>(loc, MemRefType::get({}, elementType));
rewriter.create<KrnlStoreOp>(loc, zero, reductionVal, ArrayRef<Value>{});

// Create the inner reduction loop; trip count is last dim of A.
BuildKrnlLoop innerLoops(rewriter, loc, 1);
Expand All @@ -63,7 +66,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
innerLoops.createIterateOp();

// Now start writing code inside the inner loop: get A & B access functions.
auto ipOuterLoopRegion = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());

IndexExpr k =
outerContext.createLoopInductionIndex(innerLoops.getInductionVar(0));
SmallVector<IndexExpr, 4> aAccessFct, bAccessFct;
Expand Down Expand Up @@ -100,9 +105,16 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
outerContext.createKrnlLoadOp(operandAdaptor.A(), aAccessFct);
Value loadedB =
outerContext.createKrnlLoadOp(operandAdaptor.B(), bAccessFct);
Value loadedY = outerContext.createKrnlLoadOp(alloc, resAccessFct);
Value loadedY =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
Value AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
Value accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
rewriter.create<KrnlStoreOp>(
loc, accumulated, reductionVal, ArrayRef<Value>{});

rewriter.restoreInsertionPoint(ipOuterLoopRegion);
accumulated =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
outerContext.createKrnlStoreOp(accumulated, alloc, resAccessFct);

// Done.
Expand Down
43 changes: 27 additions & 16 deletions src/Conversion/ONNXToKrnl/NN/Conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
outerLoops.getInductionVar(gIndex));
kernel = g * kernelsPerGroupValue + kernel;
}
// Evaluate kernel to emit its SSA value at this location.
kernel.getValue();

// 2.2 Define spatial loops
int64_t nSpatialLoops = resultShape.size() - spatialStartIndex;
Expand All @@ -199,7 +201,6 @@ struct ONNXConvOpLowering : public ConversionPattern {
// for rX = 0 .. RX
spatialLoops.createIterateOp();
rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());

{
// 3. Emit the body of the spatial loop nest.
// 3.1 Emit: R[n][kernel][r1][r2] = 0;
Expand All @@ -212,9 +213,16 @@ struct ONNXConvOpLowering : public ConversionPattern {
// rX
for (auto arg : spatialLoops.getIterateBlock()->getArguments())
resultIndices.emplace_back(ieContext.createLoopInductionIndex(arg));
// Store initializer value into output location.

// Initialize the output.
ieContext.createKrnlStoreOp(zero, alloc, resultIndices);

// Create a local reduction value.
Value reductionVal = rewriter.create<AllocaOp>(
loc, MemRefType::get({}, memRefType.getElementType()));
rewriter.create<KrnlStoreOp>(
loc, zero, reductionVal, ArrayRef<Value>{});

// Prepare induction variables.
SmallVector<SmallVector<IndexExpr, 4>, 4> IVExprs;
{
Expand Down Expand Up @@ -273,19 +281,8 @@ struct ONNXConvOpLowering : public ConversionPattern {
// 3.4 Emit inner loop nest.
innerLoops.createIterateOp();

// Emit the bias, if needed.
if (hasBias) {
auto loadResult = ieContext.createKrnlLoadOp(alloc, resultIndices);
SmallVector<IndexExpr, 4> biasIndices;
biasIndices.emplace_back(kernel);
auto loadBias = ieContext.createKrnlLoadOp(biasOperand, biasIndices);
auto resultWithBias =
rewriter.create<AddFOp>(loc, loadResult, loadBias);
// Store initializer value into output location.
ieContext.createKrnlStoreOp(resultWithBias, alloc, resultIndices);
}

//
auto ipOuterLoopRegion = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
{
// 4. Emit inner loop body
Expand Down Expand Up @@ -350,12 +347,26 @@ struct ONNXConvOpLowering : public ConversionPattern {
auto loadKernel =
ieContext.createKrnlLoadOp(kernelOperand, kernelIndices);
auto loadPartialSum =
ieContext.createKrnlLoadOp(alloc, resultIndices);
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
rewriter.create<MulFOp>(loc, loadData, loadKernel));
// 4.4 Store computed value into output location.
ieContext.createKrnlStoreOp(result, alloc, resultIndices);
rewriter.create<KrnlStoreOp>(
loc, result, reductionVal, ArrayRef<Value>{});
}
rewriter.restoreInsertionPoint(ipOuterLoopRegion);

auto result =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
// Store the result. Optionally add bias.
if (hasBias) {
SmallVector<IndexExpr, 4> biasIndices;
biasIndices.emplace_back(kernel);
auto loadBias = ieContext.createKrnlLoadOp(biasOperand, biasIndices);
auto resultWithBias = rewriter.create<AddFOp>(loc, result, loadBias);
ieContext.createKrnlStoreOp(resultWithBias, alloc, resultIndices);
} else
ieContext.createKrnlStoreOp(result, alloc, resultIndices);
}
}
rewriter.replaceOp(op, alloc);
Expand Down
18 changes: 13 additions & 5 deletions src/Conversion/ONNXToKrnl/NN/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,11 @@ struct ONNXPoolOpLowering : public ConversionPattern {
ieContext.createLoopInductionIndex(outputLoops.getInductionVar(i)));

// 2.1 Emit: output[n][c][ho][wo] = identity
ieContext.createKrnlStoreOp(identity, alloc, outputIndices);
// Create a local reduction value for output[n][c][ho][wo].
Value reductionVal = rewriter.create<AllocaOp>(
loc, MemRefType::get({}, memRefType.getElementType()));
rewriter.create<KrnlStoreOp>(
loc, identity, reductionVal, ArrayRef<Value>{});

// 2.2 Emit affine maps which express the lower and upper bounds for the
// pooling window's dimensions.
Expand Down Expand Up @@ -441,7 +445,7 @@ struct ONNXPoolOpLowering : public ConversionPattern {
// Create a krnl iterate.
poolingLoops.createIterateOp();

auto ipOuterLoops = rewriter.saveInsertionPoint();
auto ipOuterLoopRegion = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(poolingLoops.getIterateBlock());
{
// 2.4 Emit the body of the pooling loop nest.
Expand Down Expand Up @@ -472,14 +476,18 @@ struct ONNXPoolOpLowering : public ConversionPattern {
Value loadInput =
ieContext.createKrnlLoadOp(inputOperand, inputIndices);
Value loadPartialOutput =
ieContext.createKrnlLoadOp(alloc, outputIndices);
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
Value output = emitScalarOpFor<PoolOp>(rewriter, loc, op,
outputElementType, {loadPartialOutput, loadInput});
ieContext.createKrnlStoreOp(output, alloc, outputIndices);
rewriter.create<KrnlStoreOp>(
loc, output, reductionVal, ArrayRef<Value>{});
}
rewriter.restoreInsertionPoint(ipOuterLoopRegion);
Value output =
rewriter.create<KrnlLoadOp>(loc, reductionVal, ArrayRef<Value>{});
ieContext.createKrnlStoreOp(output, alloc, outputIndices);

// 2.5 Post-processing for the pooling window, e.g. taking average.
rewriter.restoreInsertionPoint(ipOuterLoops);
SmallVector<Value, 4> outputIndicesInValue;
for (IndexExpr expr : outputIndices)
outputIndicesInValue.emplace_back(expr.getValue());
Expand Down
8 changes: 5 additions & 3 deletions test/mlir/onnx/onnx_enable_memory_pool.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ func @test_enable_memory_pool_2(%arg0: tensor<10x10xf32>, %arg1: tensor<10x20xf3
// CHECK: krnl.store [[ADDF1]], [[GETREF1]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.iterate
// CHECK: [[REDUCTION_SUM:%.+]] = alloca() : memref<f32>
// CHECK: [[LOAD3:%.+]] = krnl.load [[GETREF1]][%arg2, %arg4] : memref<10x10xf32>
// CHECK: [[LOAD4:%.+]] = krnl.load %arg1[%arg4, %arg3] : memref<10x20xf32>
// CHECK: [[LOAD5:%.+]] = krnl.load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: [[LOAD5:%.+]] = krnl.load [[REDUCTION_SUM]][] : memref<f32>
// CHECK: [[MULF1:%.+]] = mulf [[LOAD3]], [[LOAD4]] : f32
// CHECK: [[ADDF2:%.+]] = addf [[LOAD5]], [[MULF1]] : f32
// CHECK: krnl.store [[ADDF2]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: krnl.store [[ADDF2]], [[REDUCTION_SUM]][] : memref<f32>
// CHECK: [[SUM:%.+]] = krnl.load [[REDUCTION_SUM]][] : memref<f32>
// CHECK: krnl.store [[SUM]], [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
// CHECK: krnl.define_loops
// CHECK: krnl.iterate
// CHECK: [[LOAD6:%.+]] = krnl.load [[GETREF0]][%arg2, %arg3] : memref<10x20xf32>
Expand Down Expand Up @@ -100,7 +103,6 @@ func @test_enable_memory_pool_3(%arg0: tensor<?x?xf32>, %arg1: tensor<?x10xf32>,
// CHECK: [[DATA3:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
// CHECK: krnl.define_loops 2
// CHECK: krnl.iterate
// CHECK: krnl.store [[CST]], [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
// CHECK: krnl.define_loops 1
// CHECK: krnl.iterate
// CHECK: krnl.store {{.*}}, [[DATA3]][%arg3, %arg4] : memref<?x10xf32>
Expand Down
Loading

0 comments on commit 7220d1e

Please sign in to comment.