[NFI]: Use dialect builder rather than creating operators directly (llvm#1212)

Signed-off-by: Ettore Tiotto <[email protected]>
Ettore Tiotto authored Mar 8, 2022
1 parent 0e2da77 commit 0a40daa
Showing 19 changed files with 398 additions and 478 deletions.
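Note on the pattern (annotation, not part of the commit): the change replaces direct rewriter.create<...>(loc, ...) calls with dialect builders that capture the rewriter and location once. A minimal before/after sketch using only builder calls that appear in the diffs below; the variables memref, dest, and indices are placeholders:

    // Before: every op creation repeats loc, and results are unwrapped by hand.
    Value v = rewriter.create<KrnlLoadOp>(loc, memref).getResult();
    rewriter.create<KrnlStoreOp>(loc, v, dest, indices);

    // After: a builder carries (rewriter, loc). MultiDialectBuilder composes
    // several builders behind one object with .krnl, .math, and .mem members.
    MultiDialectBuilder<KrnlBuilder, MathBuilder, MemRefBuilder> create(
        rewriter, loc);
    Value w = create.krnl.load(memref);
    create.krnl.store(w, dest, indices);

Beyond brevity, helpers such as create.math.slt and create.math.select also hide the float/integer asymmetry (arith::CmpFOp vs. arith::CmpIOp), as the Clip.cpp diff below shows.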
22 changes: 11 additions & 11 deletions src/Conversion/ONNXToKrnl/ControlFlow/Loop.cpp
@@ -57,8 +57,9 @@ struct ONNXLoopOpLowering : public ConversionPattern {
     // Create the loop iteration.
     BuildKrnlLoop loop(rewriter, loc, 1);
     loop.createDefineOp();
-    Value maxTripCount =
-        rewriter.create<KrnlLoadOp>(loc, loopOpAdapter.M()).getResult();
+    KrnlBuilder createKrnl(rewriter, loc);
+    Value maxTripCount = createKrnl.load(loopOpAdapter.M());
+
     maxTripCount = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getIndexType(), maxTripCount);
     loop.pushBounds(0, maxTripCount);
@@ -68,7 +69,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
     {
       OpBuilder::InsertionGuard insertGuard(rewriter);

-      auto condReg = rewriter.create<KrnlLoadOp>(loc, cond).getResult();
+      Value condReg = createKrnl.load(cond);
       auto ifOp = rewriter.create<scf::IfOp>(loc, condReg, false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
@@ -82,7 +83,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
       MemRefBuilder createMemRef(rewriter, loc);
       Value ivMemRef =
           createMemRef.alloc(MemRefType::get({}, rewriter.getI64Type()));
-      rewriter.create<KrnlStoreOp>(loc, iv, ivMemRef);
+      createKrnl.store(iv, ivMemRef);

       // Make the call to loop body function.
       SmallVector<Value, 4> params = {ivMemRef, loopOpAdapter.cond()};
@@ -234,6 +235,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, shouldDealloc);
     else {
+      MultiDialectBuilder<KrnlBuilder, MemRefBuilder> create(rewriter, loc);
       auto rankedScanOutTy = memRefType;
       SmallVector<mlir::Value, 4> allocParams;
       for (int i = 0; i < rankedScanOutTy.getRank(); i++) {
@@ -244,9 +246,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
           // equal to the max trip count, due to the possibility of early
           // termination.
           assert(!loopOpAdapter.M().getType().isa<NoneType>());
-          Value maxTripCount =
-              rewriter.create<KrnlLoadOp>(loc, loopOpAdapter.M())
-                  .getResult();
+          Value maxTripCount = create.krnl.load(loopOpAdapter.M());
           allocParams.emplace_back(rewriter.create<arith::IndexCastOp>(
               loc, rewriter.getIndexType(), maxTripCount));
         } else {
@@ -258,8 +258,7 @@ struct ONNXLoopOpLowering : public ConversionPattern {
           }
         }
       }
-      MemRefBuilder createMemRef(rewriter, loc);
-      alloc = createMemRef.alignedAlloc(rankedScanOutTy, allocParams);
+      alloc = create.mem.alignedAlloc(rankedScanOutTy, allocParams);
     }
     outputs.emplace_back(alloc);
   }
@@ -291,8 +290,9 @@ struct ONNXLoopOpLowering : public ConversionPattern {
     }
     SmallVector<Value, 4> writeIV(writePrefix.begin(), writePrefix.end());
     writeIV.insert(writeIV.end(), readIV.begin(), readIV.end());
-    auto val = rewriter.create<KrnlLoadOp>(loc, src, readIV).getResult();
-    rewriter.create<KrnlStoreOp>(loc, val, dest, writeIV);
+    KrnlBuilder createKrnl(rewriter, loc);
+    Value val = createKrnl.load(src, readIV);
+    createKrnl.store(val, dest, writeIV);
   }
 };

30 changes: 16 additions & 14 deletions src/Conversion/ONNXToKrnl/ControlFlow/Scan.cpp
@@ -281,52 +281,54 @@ struct ONNXScanOpLowering : public ConversionPattern {
   // into a higher dimensional tensor with shape (10x4x2), i.e., a batch of 10
   // tensors, each with shape (4x2). To do so, we can invoke emitCopy(src, dest,
   // {0}).
-  static void emitCopy(ConversionPatternRewriter &rewriter, const Location &loc,
+  static void emitCopy(OpBuilder &builder, const Location &loc,
       const Value &src, const Value &dest,
       std::vector<Value> writePrefix = {}) {
-    OpBuilder::InsertionGuard insertGuard(rewriter);
+    OpBuilder::InsertionGuard insertGuard(builder);

     auto srcTy = src.getType().cast<MemRefType>();
     SmallVector<Value, 4> readIV;
     if (srcTy.getRank() > 0) {
-      BuildKrnlLoop loop(rewriter, loc, srcTy.getRank());
+      BuildKrnlLoop loop(builder, loc, srcTy.getRank());
       loop.createDefineOp();
       for (int i = 0; i < srcTy.getRank(); i++)
         loop.pushBounds(0, src, i);
       loop.createIterateOp();
-      rewriter.setInsertionPointToStart(loop.getIterateBlock());
+      builder.setInsertionPointToStart(loop.getIterateBlock());
       auto loopIVs = loop.getAllInductionVar();
       readIV = SmallVector<Value, 4>(loopIVs.begin(), loopIVs.end());
     }

     SmallVector<Value, 4> writeIV(writePrefix.begin(), writePrefix.end());
     writeIV.insert(writeIV.end(), readIV.begin(), readIV.end());
-    auto val = rewriter.create<KrnlLoadOp>(loc, src, readIV).getResult();
-    rewriter.create<KrnlStoreOp>(loc, val, dest, writeIV);
+
+    KrnlBuilder createKrnl(builder, loc);
+    Value val = createKrnl.load(src, readIV);
+    createKrnl.store(val, dest, writeIV);
   }

-  static void emitCopyFromTensorSlice(ConversionPatternRewriter &rewriter,
-      const Location &loc, const Value &src, const Value &dest,
-      std::vector<Value> readPrefix = {}) {
-    OpBuilder::InsertionGuard insertGuard(rewriter);
+  static void emitCopyFromTensorSlice(OpBuilder &builder, const Location &loc,
+      const Value &src, const Value &dest, std::vector<Value> readPrefix = {}) {
+    OpBuilder::InsertionGuard insertGuard(builder);

     auto srcTy = src.getType().cast<MemRefType>();
     SmallVector<Value, 4> readIV(readPrefix.begin(), readPrefix.end());
     SmallVector<Value, 4> writeIV;
     if ((size_t)srcTy.getRank() > readIV.size()) {
-      BuildKrnlLoop loop(rewriter, loc, srcTy.getRank() - readPrefix.size());
+      BuildKrnlLoop loop(builder, loc, srcTy.getRank() - readPrefix.size());
       loop.createDefineOp();
       for (int i = readIV.size(); i < srcTy.getRank(); i++)
         loop.pushBounds(0, src, i);
       loop.createIterateOp();
-      rewriter.setInsertionPointToStart(loop.getIterateBlock());
+      builder.setInsertionPointToStart(loop.getIterateBlock());
       auto IVs = loop.getAllInductionVar();
       writeIV.insert(writeIV.end(), IVs.begin(), IVs.end());
       readIV.insert(readIV.end(), writeIV.begin(), writeIV.end());
     }

-    auto val = rewriter.create<KrnlLoadOp>(loc, src, readIV).getResult();
-    rewriter.create<KrnlStoreOp>(loc, val, dest, writeIV);
+    KrnlBuilder createKrnl(builder, loc);
+    Value val = createKrnl.load(src, readIV);
+    createKrnl.store(val, dest, writeIV);
   }
 };
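Side note on emitCopy (annotation, not part of the commit): for the example in the comment above, emitCopy(src, dest, {0}) with src of shape (4x2) and dest of shape (10x4x2) produces a krnl loop nest roughly equivalent to:

    // writePrefix = {0}; readIV = {i, j}; writeIV = writePrefix + readIV.
    for (int64_t i = 0; i < 4; ++i)
      for (int64_t j = 0; j < 2; ++j)
        dest[0][i][j] = src[i][j];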

34 changes: 16 additions & 18 deletions src/Conversion/ONNXToKrnl/Math/Clip.cpp
@@ -61,41 +61,39 @@ struct ONNXClipOpLowering : public ConversionPattern {
     }

     // Load unary first operand.
-    Value loadedVal = rewriter.create<KrnlLoadOp>(loc, input, loopIVs);
+    MultiDialectBuilder<KrnlBuilder, MathBuilder> create(rewriter, loc);
+    Value loadedVal = create.krnl.load(input, loopIVs);
     Type inputType = loadedVal.getType();
     Value res = loadedVal;

     if (inputType.isa<FloatType>()) {
       if (!min.getType().isa<NoneType>()) {
-        Value minVal = rewriter.create<KrnlLoadOp>(loc, min).getResult();
-        Value lessThanMin = rewriter.create<arith::CmpFOp>(
-            loc, arith::CmpFPredicate::OLT, res, minVal);
-        res = rewriter.create<arith::SelectOp>(loc, lessThanMin, minVal, res);
+        Value minVal = create.krnl.load(min);
+        Value lessThanMin = create.math.slt(res, minVal);
+        res = create.math.select(lessThanMin, minVal, res);
       }
       if (!max.getType().isa<NoneType>()) {
-        Value maxVal = rewriter.create<KrnlLoadOp>(loc, max).getResult();
-        Value lessThanMax = rewriter.create<arith::CmpFOp>(
-            loc, arith::CmpFPredicate::OLT, res, maxVal);
-        res = rewriter.create<arith::SelectOp>(loc, lessThanMax, res, maxVal);
+        Value maxVal = create.krnl.load(max);
+        Value lessThanMax = create.math.slt(res, maxVal);
+        res = create.math.select(lessThanMax, res, maxVal);
       }
     } else if (inputType.isa<IntegerType>()) {
       if (!min.getType().isa<NoneType>()) {
-        Value minVal = rewriter.create<KrnlLoadOp>(loc, min).getResult();
-        Value lessThanMin = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::slt, res, minVal);
-        res = rewriter.create<arith::SelectOp>(loc, lessThanMin, minVal, res);
+        Value minVal = create.krnl.load(min);
+        Value lessThanMin = create.math.slt(res, minVal);
+        res = create.math.select(lessThanMin, minVal, res);
       }
       if (!max.getType().isa<NoneType>()) {
-        Value maxVal = rewriter.create<KrnlLoadOp>(loc, max).getResult();
-        Value lessThanMax = rewriter.create<arith::CmpIOp>(
-            loc, arith::CmpIPredicate::slt, res, maxVal);
-        res = rewriter.create<arith::SelectOp>(loc, lessThanMax, res, maxVal);
+        Value maxVal = create.krnl.load(max);
+        Value lessThanMax = create.math.slt(res, maxVal);
+        res = create.math.select(lessThanMax, res, maxVal);
       }
     } else {
       llvm_unreachable("unsupported element type");
     }

     // Store result in the resulting array.
-    rewriter.create<KrnlStoreOp>(loc, res, alloc, loopIVs);
+    create.krnl.store(res, alloc, loopIVs);

     rewriter.replaceOp(op, alloc);
     return success();
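Side note (annotation, not part of the commit): both the FloatType and IntegerType branches above now issue identical create.math.slt and create.math.select calls, so the MathBuilder helpers must dispatch on the operand's element type. A hypothetical sketch of such a helper, with the predicates taken from the removed lines; b and loc stand for the builder's stored OpBuilder and Location, and this is not the actual onnx-mlir implementation:

    Value MathBuilder::slt(Value lhs, Value rhs) const {
      // Ordered less-than for floats, signed less-than for integers.
      if (lhs.getType().isa<FloatType>())
        return b.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::OLT, lhs, rhs);
      return b.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, lhs, rhs);
    }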
5 changes: 3 additions & 2 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -903,11 +903,12 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
       loopIVs.push_back(arg);
     }

-    auto loadedVal = rewriter.create<KrnlLoadOp>(loc, X, loopIVs);
+    KrnlBuilder createKrnl(rewriter, loc);
+    Value loadedVal = createKrnl.load(X, loopIVs);
     auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
         rewriter, loc, op, memRefType.getElementType(), {loadedVal});
     // Store result in the resulting array.
-    rewriter.create<KrnlStoreOp>(loc, loweredOpResult, alloc, loopIVs);
+    createKrnl.store(loweredOpResult, alloc, loopIVs);

     rewriter.replaceOp(op, alloc);
     return success();
39 changes: 19 additions & 20 deletions src/Conversion/ONNXToKrnl/Math/LRN.cpp
@@ -87,11 +87,11 @@ struct ONNXLRNOpLowering : public ConversionPattern {

     // Initialize sum, single scalar, no need for default alignment.
     MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
-    MemRefBuilder createMemRef(rewriter, loc);
-    Value sumAlloc = createMemRef.alloc(scalarMemRefType);
-    rewriter.create<KrnlStoreOp>(loc,
-        emitConstantOp(rewriter, loc, elementType, 0), sumAlloc,
-        ArrayRef<Value>{});
+    MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
+        rewriter, loc);
+
+    Value sumAlloc = create.mem.alloc(scalarMemRefType);
+    create.krnl.store(emitConstantOp(rewriter, loc, elementType, 0), sumAlloc);

     // Create the sum reduction loop
     BuildKrnlLoop sumLoops(rewriter, loc, 1);
@@ -113,13 +113,12 @@ struct ONNXLRNOpLowering : public ConversionPattern {
       }
     }

-    Value loadVal = rewriter.create<KrnlLoadOp>(loc, input, loadIndices);
-    Value squareVal = rewriter.create<arith::MulFOp>(loc, loadVal, loadVal);
+    Value loadVal = create.krnl.load(input, loadIndices);
+    Value squareVal = create.math.mul(loadVal, loadVal);

-    Value sumValue =
-        rewriter.create<KrnlLoadOp>(loc, sumAlloc, ArrayRef<Value>{});
-    sumValue = rewriter.create<arith::AddFOp>(loc, sumValue, squareVal);
-    rewriter.create<KrnlStoreOp>(loc, sumValue, sumAlloc, ArrayRef<Value>{});
+    Value sumValue = create.krnl.load(sumAlloc, ArrayRef<Value>{});
+    sumValue = create.math.add(sumValue, squareVal);
+    create.krnl.store(sumValue, sumAlloc, ArrayRef<Value>{});

     // Compute and store the output
     // y = x / ((bias + (alpha / nsize) * square_sum) ** beta)
@@ -128,15 +127,15 @@ struct ONNXLRNOpLowering : public ConversionPattern {
     for (int i = 0; i < outputRank; ++i) {
       storeIndices.emplace_back(outputLoops.getInductionVar(i));
     }
-    Value xValue = rewriter.create<KrnlLoadOp>(loc, input, storeIndices);
-    sumValue = rewriter.create<KrnlLoadOp>(loc, sumAlloc, ArrayRef<Value>{});
-    Value tempValue = rewriter.create<math::PowFOp>(loc,
-        rewriter.create<arith::AddFOp>(loc, biasValue,
-            rewriter.create<arith::MulFOp>(loc, alphaDivSizeValue, sumValue)),
-        betaValue);
-    Value resultValue = rewriter.create<arith::DivFOp>(loc, xValue, tempValue);
-
-    rewriter.create<KrnlStoreOp>(loc, resultValue, alloc, storeIndices);
+    Value xValue = create.krnl.load(input, storeIndices);
+    sumValue = create.krnl.load(sumAlloc);
+    Value tempValue =
+        create.math.pow(create.math.add(biasValue,
+                            create.math.mul(alphaDivSizeValue, sumValue)),
+            betaValue);
+    Value resultValue = create.math.div(xValue, tempValue);
+
+    create.krnl.store(resultValue, alloc, storeIndices);

     rewriter.replaceOp(op, alloc);

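For reference (annotation, not part of the commit): the chained builder calls above compute the formula from the comment, y = x / ((bias + (alpha / nsize) * square_sum) ** beta), elementwise:

    // tempValue   = pow(biasValue + alphaDivSizeValue * sumValue, betaValue)
    // resultValue = xValue / tempValue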
(Diffs for the remaining 14 of the 19 changed files were not loaded on this page.)
