Skip to content

Commit

Permalink
[Concat] Use builder based interface to generate Krnl loops (llvm#1316)
Browse files: browse the repository at this point in the history
Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang authored Apr 5, 2022
1 parent 600e28b commit 1869fa6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
54 changes: 27 additions & 27 deletions src/Conversion/ONNXToKrnl/Tensor/Concat.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,36 +50,36 @@ struct ONNXConcatOpLowering : public ConversionPattern {
MultiDialectBuilder<KrnlBuilder> create(rewriter, loc);

// Creates loops, one for each input.
KrnlBuilder createKrnl(rewriter, loc);
for (unsigned int i = 0; i < inputNum; ++i) {
OpBuilder::InsertionGuard insertGuard(rewriter);
// Create loop.
BuildKrnlLoop inputLoops(rewriter, loc, rank);
inputLoops.createDefineOp();
for (unsigned int r = 0; r < rank; ++r)
inputLoops.pushBounds(0, operands[i], r);
inputLoops.createIterateOp();
rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());

// Indices for the read and write.
SmallVector<Value, 4> readIndices;
SmallVector<Value, 4> writeIndices;
for (unsigned int r = 0; r < rank; ++r) {
readIndices.emplace_back(inputLoops.getInductionVar(r));
if (r != axis || i == 0)
writeIndices.emplace_back(inputLoops.getInductionVar(r));
else {
IndexExprScope IEScope(&rewriter, loc);
IndexExpr writeOffset = DimIndexExpr(inputLoops.getInductionVar(r));
for (unsigned int j = 0; j < i; j++) {
MemRefBoundsIndexCapture operandJBounds(operands[j]);
writeOffset = writeOffset + operandJBounds.getDim(r);
}
writeIndices.emplace_back(writeOffset.getValue());
}
}
// Insert copy.
Value loadData = create.krnl.load(operands[i], readIndices);
create.krnl.store(loadData, alloc, writeIndices);
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));
MemRefBoundsIndexCapture bounds(operands[i]);
SmallVector<IndexExpr, 4> ubs;
bounds.getDimList(ubs);
createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange loopInd) {
// Indices for the read and write.
SmallVector<Value, 4> readIndices, writeIndices;
for (unsigned int r = 0; r < rank; ++r) {
if (r != axis || i == 0)
writeIndices.emplace_back(loopInd[r]);
else {
IndexExprScope IEScope(&rewriter, loc);
IndexExpr writeOffset = DimIndexExpr(loopInd[r]);
for (unsigned int j = 0; j < i; j++) {
MemRefBoundsIndexCapture operandJBounds(operands[j]);
writeOffset = writeOffset + operandJBounds.getDim(r);
}
writeIndices.emplace_back(writeOffset.getValue());
}
}
// Insert copy.
Value loadData = createKrnl.load(operands[i], loopInd);
createKrnl.store(loadData, alloc, writeIndices);
});
}
rewriter.replaceOp(op, alloc);
return success();
Expand Down
17 changes: 9 additions & 8 deletions test/mlir/onnx/onnx_lowering_with_canonicalize.mlir
Original file line number | Diff line number | Diff line change
Expand Up @@ -872,20 +872,21 @@ func private @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x3
// CHECK: [[RES:%.+]] = memref.alloc() {{.*}}: memref<5x5x9x32xf32>
// CHECK: [[DEF_LOOPS0:%.+]]:4 = krnl.define_loops 4
// CHECK: krnl.iterate([[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3) with ([[DEF_LOOPS0]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS0]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS0]]#2 -> %arg5 = 0 to 1, [[DEF_LOOPS0]]#3 -> %arg6 = 0 to 32){
// CHECK: [[LOAD0:%.+]] = krnl.load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<5x5x1x32xf32>
// CHECK: krnl.store [[LOAD0]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<5x5x9x32xf32>
// CHECK: [[IV:%.+]]:4 = krnl.get_induction_var_value([[DEF_LOOPS0]]#0, [[DEF_LOOPS0]]#1, [[DEF_LOOPS0]]#2, [[DEF_LOOPS0]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index)
// CHECK: [[LOAD0:%.+]] = krnl.load %arg0[[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3] : memref<5x5x1x32xf32>
// CHECK: krnl.store [[LOAD0]], [[RES]][[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3] : memref<5x5x9x32xf32>

// CHECK: [[DEF_LOOPS1:%.+]]:4 = krnl.define_loops 4
// CHECK: krnl.iterate([[DEF_LOOPS1]]#0, [[DEF_LOOPS1]]#1, [[DEF_LOOPS1]]#2, [[DEF_LOOPS1]]#3) with ([[DEF_LOOPS1]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS1]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS1]]#2 -> %arg5 = 0 to 3, [[DEF_LOOPS1]]#3 -> %arg6 = 0 to 32){
// CHECK: [[AFFINE_APPLY1:%.+]] = affine.apply #{{.*}}(%arg5)
// CHECK: [[LOAD1:%.+]] = krnl.load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<5x5x3x32xf32>
// CHECK: krnl.store [[LOAD1]], [[RES]][%arg3, %arg4, [[AFFINE_APPLY1]], %arg6] : memref<5x5x9x32xf32>
// CHECK: [[AFFINE_APPLY1:%.+]] = affine.apply #{{.*}}([[IV]]#2)
// CHECK: [[LOAD1:%.+]] = krnl.load %arg1[[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3] : memref<5x5x3x32xf32>
// CHECK: krnl.store [[LOAD1]], [[RES]][[[IV]]#0, [[IV]]#1, [[AFFINE_APPLY1]], [[IV]]#3] : memref<5x5x9x32xf32>

// CHECK: [[DEF_LOOPS2:%.+]]:4 = krnl.define_loops 4
// CHECK: krnl.iterate([[DEF_LOOPS2]]#0, [[DEF_LOOPS2]]#1, [[DEF_LOOPS2]]#2, [[DEF_LOOPS2]]#3) with ([[DEF_LOOPS2]]#0 -> %arg3 = 0 to 5, [[DEF_LOOPS2]]#1 -> %arg4 = 0 to 5, [[DEF_LOOPS2]]#2 -> %arg5 = 0 to 5, [[DEF_LOOPS2]]#3 -> %arg6 = 0 to 32){
// CHECK: [[AFFINE_APPLY2:%.+]] = affine.apply #{{.*}}(%arg5)
// CHECK: [[LOAD2:%.+]] = krnl.load %arg2[%arg3, %arg4, %arg5, %arg6] : memref<5x5x5x32xf32>
// CHECK: krnl.store [[LOAD2]], [[RES]][%arg3, %arg4, [[AFFINE_APPLY2]], %arg6] : memref<5x5x9x32xf32>
// CHECK: [[AFFINE_APPLY2:%.+]] = affine.apply #{{.*}}([[IV]]#2)
// CHECK: [[LOAD2:%.+]] = krnl.load %arg2[[[IV]]#0, [[IV]]#1, [[IV]]#2, [[IV]]#3] : memref<5x5x5x32xf32>
// CHECK: krnl.store [[LOAD2]], [[RES]][[[IV]]#0, [[IV]]#1, [[AFFINE_APPLY2]], [[IV]]#3] : memref<5x5x9x32xf32>

// CHECK: return [[RES]] : memref<5x5x9x32xf32>
}
Expand Down

0 comments on commit 1869fa6

Please sign in to comment.