Skip to content

Commit

Permalink
Use builder based interface to generate Krnl loops (llvm#1250)
Browse files Browse the repository at this point in the history
* [ClipOp]: Use builder based interface to generate Krnl loops

Signed-off-by: Ettore Tiotto <[email protected]>

* [TileOp]: Use builder based interface to generate Krnl loops

Signed-off-by: Ettore Tiotto <[email protected]>

* [Transpose]: Use builder based interface to generate Krnl loops

Signed-off-by: Ettore Tiotto <[email protected]>

* [Transpose]: Remove #ifdef out code

Signed-off-by: Ettore Tiotto <[email protected]>

* [Split]: Use builder based interface to generate Krnl loops

Signed-off-by: Ettore Tiotto <[email protected]>

* Address code review comments

Signed-off-by: Ettore Tiotto <[email protected]>

Co-authored-by: Tung D. Le <[email protected]>
  • Loading branch information
Ettore Tiotto and tungld authored Mar 25, 2022
1 parent bbda7b2 commit f8f0a1e
Show file tree
Hide file tree
Showing 10 changed files with 380 additions and 299 deletions.
114 changes: 54 additions & 60 deletions src/Conversion/ONNXToKrnl/Math/Clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,73 +28,67 @@ struct ONNXClipOpLowering : public ConversionPattern {
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = operands[0];
Value min = operands[1];
Value max = operands[2];
ONNXClipOp clipOp = cast<ONNXClipOp>(op);
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());

// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());

Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXClipOpAdaptor operandAdaptor(operands);
ONNXClipOpShapeHelper shapeHelper(&clipOp, &rewriter,
getDenseElementAttributeFromKrnlValue,
loadDenseElementArrayValueAtIndex);
auto shapeComputed = shapeHelper.computeShape(operandAdaptor);
assert(succeeded(shapeComputed));

if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, input);
Value input = operandAdaptor.input();
Value min = operandAdaptor.min();
Value max = operandAdaptor.max();

SmallVector<Value, 4> loopIVs;
// Only create krnl.iterate if one of the operands is not a scalar tensor.
// Insert an allocation and deallocation for the result of this operation.
bool insertDealloc = checkInsertDealloc(op);
Value alloc =
(hasAllConstantDimensions(memRefType))
? insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc)
: insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, input);

auto computeResult =
[&](MultiDialectBuilder<KrnlBuilder, MathBuilder> &create,
const ValueRange &indices) {
Value loadedVal = create.krnl.load(input, indices);
Value res = loadedVal;
if (!min.getType().isa<NoneType>()) {
Value minVal = create.krnl.load(min);
Value lessThanMin = create.math.slt(res, minVal);
res = create.math.select(lessThanMin, minVal, res);
}
if (!max.getType().isa<NoneType>()) {
Value maxVal = create.krnl.load(max);
Value lessThanMax = create.math.slt(res, maxVal);
res = create.math.select(lessThanMax, res, maxVal);
}
create.krnl.store(res, alloc, indices);
};

// Create a loop only if one of the operands is not a scalar tensor.
if (!hasAllScalarValues(operands)) {
// Create iterateOp & get block within iterate op.
BuildKrnlLoop loops(rewriter, loc, memRefType.getRank());
loops.createDefineAndIterateOp(input);
Block *iterationBlock = loops.getIterateBlock();

// Insert instructions inside the KernelIterateOp body.
rewriter.setInsertionPointToStart(iterationBlock);

// Handle the operation:
for (auto arg : iterationBlock->getArguments())
loopIVs.push_back(arg);
}

// Load unary first operand.
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
Value loadedVal = create.krnl.load(input, loopIVs);
Type inputType = loadedVal.getType();
Value res = loadedVal;

if (inputType.isa<FloatType>()) {
if (!min.getType().isa<NoneType>()) {
Value minVal = create.krnl.load(min);
Value lessThanMin = create.math.slt(res, minVal);
res = create.math.select(lessThanMin, minVal, res);
}
if (!max.getType().isa<NoneType>()) {
Value maxVal = create.krnl.load(max);
Value lessThanMax = create.math.slt(res, maxVal);
res = create.math.select(lessThanMax, res, maxVal);
}
} else if (inputType.isa<IntegerType>()) {
if (!min.getType().isa<NoneType>()) {
Value minVal = create.krnl.load(min);
Value lessThanMin = create.math.slt(res, minVal);
res = create.math.select(lessThanMin, minVal, res);
}
if (!max.getType().isa<NoneType>()) {
Value maxVal = create.krnl.load(max);
Value lessThanMax = create.math.slt(res, maxVal);
res = create.math.select(lessThanMax, res, maxVal);
}
KrnlBuilder createKrnl(rewriter, loc);
uint64_t numLoops = memRefType.getRank();
ValueRange loopDef = createKrnl.defineLoops(numLoops);

SmallVector<IndexExpr, 4> lbs(numLoops, LiteralIndexExpr(0));
SmallVector<IndexExpr, 4> ubs;
for (uint64_t i = 0; i < numLoops; ++i)
ubs.emplace_back(shapeHelper.dimsForOutput()[i]);

createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(createKrnl);
computeResult(create, indices);
});
} else {
llvm_unreachable("unsupported element type");
MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
computeResult(create, {});
}

// Store result in the resulting array.
create.krnl.store(res, alloc, loopIVs);

rewriter.replaceOp(op, alloc);
return success();
}
Expand Down
69 changes: 37 additions & 32 deletions src/Conversion/ONNXToKrnl/Tensor/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ template <typename Adaptor, typename Op, typename ShapeHelper>
LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
// Gather info.
auto loc = op->getLoc();
Location loc = op->getLoc();
Adaptor operandAdaptor(operands);
Op splitOp = llvm::dyn_cast<Op>(op);
auto rank = splitOp.input().getType().template cast<ShapedType>().getRank();
auto outputNum = splitOp.getNumResults();
auto axis = splitOp.axis();
Op splitOp = cast<Op>(op);
uint64_t rank =
splitOp.input().getType().template cast<ShapedType>().getRank();
unsigned outputNum = splitOp.getNumResults();
unsigned axis = splitOp.axis();

// Get a shape helper.
ShapeHelper shapeHelper(&splitOp, &rewriter,
Expand All @@ -36,7 +37,7 @@ LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,

// Alloc and dealloc.
SmallVector<Value, 4> allocs;
for (unsigned int i = 0; i < outputNum; ++i) {
for (unsigned i = 0; i < outputNum; ++i) {
checkInsertDealloc(op, i);
auto memRefType = convertToMemRefType(splitOp.outputs()[i].getType());
Value alloc = insertAllocAndDeallocSimple(
Expand All @@ -45,40 +46,44 @@ LogicalResult ONNXSplitOpLoweringCommon(Operation *op, ArrayRef<Value> operands,
}

// Creates loops, one for each output.
for (unsigned int i = 0; i < outputNum; ++i) {
for (unsigned i = 0; i < outputNum; ++i) {
OpBuilder::InsertionGuard insertGuard(rewriter);
// Create loop.
BuildKrnlLoop outputLoops(rewriter, loc, rank);
outputLoops.createDefineAndIterateOp(allocs[i]);
rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());

// Scope for krnl ops
IndexExprScope childScope(&rewriter, shapeHelper.scope);

KrnlBuilder createKrnl(rewriter, loc);
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));

MemRefBoundsIndexCapture allocsBounds(allocs[i]);
SmallVector<IndexExpr, 4> ubs;
allocsBounds.getDimList(ubs);

createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
SmallVector<IndexExpr, 4> readIndices;
for (uint64_t r = 0; r < rank; ++r) {
DimIndexExpr readIndex(indices[r]);
// Compute read index for the split axis.
if (r == axis)
for (unsigned k = 0; k < i; ++k) {
SymbolIndexExpr splitDim(shapeHelper.dimsForOutput(k)[r]);
readIndex = readIndex + splitDim;
}

// Indices for the read and write.
SmallVector<IndexExpr, 4> readIndices;
SmallVector<IndexExpr, 4> writeIndices;
for (int r = 0; r < rank; ++r) {
Value readVal = outputLoops.getInductionVar(r);
// If not the split axis, same index for read and write
IndexExpr readIndex = DimIndexExpr(readVal);
DimIndexExpr writeIndex(readVal);
// If the split axis, compute read index for the split axis.
if (r == axis) {
for (unsigned int k = 0; k < i; ++k) {
IndexExpr splitDim = SymbolIndexExpr(shapeHelper.dimsForOutput(k)[r]);
readIndex = readIndex + splitDim;
}
}
readIndices.emplace_back(readIndex);
writeIndices.emplace_back(writeIndex);
}
// Insert copy.
Value loadData = createKrnl.loadIE(operandAdaptor.input(), readIndices);
createKrnl.storeIE(loadData, allocs[i], writeIndices);
readIndices.emplace_back(readIndex);
}

// Insert copy.
Value loadData =
createKrnl.loadIE(operandAdaptor.input(), readIndices);
createKrnl.store(loadData, allocs[i], indices);
});
}

rewriter.replaceOp(op, allocs);

return success();
}

Expand Down
65 changes: 26 additions & 39 deletions src/Conversion/ONNXToKrnl/Tensor/Tile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,35 @@ struct ONNXTileOpLowering : public ConversionPattern {
(void)shapecomputed;
assert(!failed(shapecomputed) && "expected to succeed");

MemRefType outputMemRefType = convertToMemRefType(*op->result_type_begin());
auto outputMemRefShape = outputMemRefType.getShape();
int64_t outputRank = outputMemRefShape.size();
MemRefType memRefType = convertToMemRefType(*op->result_type_begin());
llvm::ArrayRef<int64_t> memRefShape = memRefType.getShape();
uint64_t outputRank = memRefShape.size();

Value input = operandAdaptor.input();

Value alloc = insertAllocAndDeallocSimple(
rewriter, op, outputMemRefType, loc, shapeHelper.dimsForOutput(0));

// Define loops and iteration trip counts (equivalent to size of output)
BuildKrnlLoop outputLoops(rewriter, loc, outputRank);
outputLoops.createDefineOp();
outputLoops.pushAllBounds(shapeHelper.dimsForOutput(0));
outputLoops.createIterateOp();
rewriter.setInsertionPointToStart(outputLoops.getIterateBlock());

SmallVector<Value, 4> loadIndices;
// This implementation is to iterate the output tensor.
// The store has simple affine subscript expression.
// Alternative implementation is to iterate the input tensor and repeats.
// The load of elements in input tensor can be reused explicitly.
// But the subscript of store is not contiguous, or even not affine.
// Alternative implementation can be found at the end of this file.

for (int64_t i = 0; i < outputRank; i++) {
// Scope is created for each dimension because they are independent
IndexExprScope IEScope(&rewriter, loc);
DimIndexExpr index(outputLoops.getInductionVar(i));
MemRefBoundsIndexCapture inputBounds(input);
DimIndexExpr dimSize(inputBounds.getDim(i));
IndexExpr exprVal = index % dimSize;
loadIndices.emplace_back(exprVal.getValue());
}

MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
rewriter, loc);
Value loadVal = create.krnl.load(input, loadIndices);

SmallVector<Value, 4> storeIndices;
for (int64_t i = 0; i < outputRank; ++i)
storeIndices.emplace_back(outputLoops.getInductionVar(i));
create.krnl.store(loadVal, alloc, storeIndices);
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput());

KrnlBuilder createKrnl(rewriter, loc);
ValueRange loopDef = createKrnl.defineLoops(outputRank);
SmallVector<IndexExpr, 4> lbs(outputRank, LiteralIndexExpr(0));

MemRefBoundsIndexCapture inputBounds(input);
createKrnl.iterateIE(loopDef, loopDef, lbs, shapeHelper.dimsForOutput(),
[&](KrnlBuilder &createKrnl, ValueRange indices) {
// Compute the indices used by the input tensor load operation.
// Note: An alternative implementation can be found at the end of this
// file.
SmallVector<Value, 4> loadIndices;
for (uint64_t i = 0; i < outputRank; ++i) {
DimIndexExpr index(indices[i]);
DimIndexExpr dimSize(inputBounds.getDim(i));
IndexExpr exprVal = index % dimSize;
loadIndices.emplace_back(exprVal.getValue());
}

Value loadVal = createKrnl.load(input, loadIndices);
createKrnl.store(loadVal, alloc, indices);
});

rewriter.replaceOp(op, alloc);

Expand Down
50 changes: 23 additions & 27 deletions src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {

// Basic information.
auto memRefType = convertToMemRefType(*op->result_type_begin());
int64_t rank = memRefType.getShape().size();
uint64_t rank = memRefType.getShape().size();

// Get a shape helper.
ONNXTransposeOpShapeHelper shapeHelper(&transposeOp, &rewriter,
Expand All @@ -46,32 +46,28 @@ struct ONNXTransposeOpLowering : public ConversionPattern {

// Insert an allocation and deallocation for the result of this operation.
Value alloc = insertAllocAndDeallocSimple(
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput(0));

// Create loop.
BuildKrnlLoop inputLoops(rewriter, loc, rank);
inputLoops.createDefineAndIterateOp(data);
rewriter.setInsertionPointToStart(inputLoops.getIterateBlock());
{
// Get a child IndexExpr context.
IndexExprScope childScope(&rewriter, shapeHelper.scope);
KrnlBuilder createKrnl(rewriter, loc);

// Get read/write indices.
SmallVector<IndexExpr, 4> readIndices;
SmallVector<IndexExpr, 4> writeIndices;
for (decltype(rank) i = 0; i < rank; ++i) {
Value readVal = inputLoops.getInductionVar(i);
Value writeVal =
inputLoops.getInductionVar(ArrayAttrIntVal(permAttr, i));
readIndices.emplace_back(DimIndexExpr(readVal));
writeIndices.emplace_back(DimIndexExpr(writeVal));
}

// Copy data.
Value loadData = createKrnl.loadIE(data, readIndices);
createKrnl.storeIE(loadData, alloc, writeIndices);
}
rewriter, op, memRefType, loc, shapeHelper.dimsForOutput());

KrnlBuilder createKrnl(rewriter, loc);
ValueRange loopDef = createKrnl.defineLoops(rank);
SmallVector<IndexExpr, 4> lbs(rank, LiteralIndexExpr(0));

MemRefBoundsIndexCapture dataBounds(data);
SmallVector<IndexExpr, 4> ubs;
dataBounds.getDimList(ubs);

createKrnl.iterateIE(loopDef, loopDef, lbs, ubs,
[&](KrnlBuilder &createKrnl, ValueRange indices) {
// Compute the indices used by the store operation.
SmallVector<IndexExpr, 4> storeIndices;
for (uint64_t i = 0; i < rank; ++i) {
Value index = indices[ArrayAttrIntVal(permAttr, i)];
storeIndices.emplace_back(DimIndexExpr(index));
}

Value loadData = createKrnl.load(data, indices);
createKrnl.storeIE(loadData, alloc, storeIndices);
});

rewriter.replaceOp(op, alloc);

Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_onnx_mlir_library(OMONNXOps
ShapeInference/ArgMax.cpp
ShapeInference/AveragePool.cpp
ShapeInference/CategoryMapper.cpp
ShapeInference/Clip.cpp
ShapeInference/Compress.cpp
ShapeInference/Concat.cpp
ShapeInference/Conv.cpp
Expand Down
Loading

0 comments on commit f8f0a1e

Please sign in to comment.