Use IndexExpr to infer shape for Conv/Pooling (llvm#764)
* Test conv with dynamic dimensions

Signed-off-by: Tung D. Le <[email protected]>

* Borrow insertAllocAndDealloc from the pooling ops

Signed-off-by: Tung D. Le <[email protected]>

* Test ResNet50 with unknown dimensions

Signed-off-by: Tung D. Le <[email protected]>

* Use IndexExpr for conv

Signed-off-by: Tung D. Le <[email protected]>

* Clean up

Signed-off-by: Tung D. Le <[email protected]>

* Clean up

Signed-off-by: Tung D. Le <[email protected]>

* Use IndexExpr for pooling ops

Signed-off-by: Tung D. Le <[email protected]>

* clang-format

Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jul 12, 2021
1 parent 2a0740d commit 4b6b99d
Showing 9 changed files with 364 additions and 136 deletions.
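
For orientation before the per-file diffs: every shape helper introduced here evaluates the standard ONNX output-size formula for convolution and pooling windows. Below is a minimal standalone sketch of that arithmetic in plain C++; it is not code from this patch, and the use of -1 for an unknown extent is an illustrative convention mirroring MLIR's dynamic-dimension marker (the IndexExpr machinery instead keeps such dims symbolic so runtime code can still be emitted).

#include <cmath>
#include <cstdint>

// Output extent of one spatial axis of Conv/MaxPool/AveragePool:
//   floor_or_ceil((in + padBegin + padEnd - dilation*(kernel - 1) - 1) / stride) + 1
int64_t spatialOutDim(int64_t in, int64_t kernel, int64_t padBegin,
    int64_t padEnd, int64_t stride, int64_t dilation, bool ceilMode) {
  if (in < 0)
    return -1; // Unknown input extent stays unknown statically.
  double d =
      double(in + padBegin + padEnd - dilation * (kernel - 1) - 1) / stride;
  return (ceilMode ? int64_t(std::ceil(d)) : int64_t(std::floor(d))) + 1;
}

For example, a 7x7 kernel over a 224-wide input with pads 3 and 3, stride 2, and dilation 1 gives (224 + 6 - 6 - 1) / 2 + 1 = 112.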
20 changes: 12 additions & 8 deletions src/Conversion/ONNXToKrnl/NN/Conv.cpp
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/ONNX/ONNXShapeHelper.hpp"
 
 using namespace mlir;
 
@@ -59,6 +60,15 @@ struct ONNXConvOpLowering : public ConversionPattern {
     for (Attribute stride : stridesAttribute.getValue())
       strides.emplace_back(stride.cast<IntegerAttr>().getInt());
 
+    // Get shape.
+    ONNXConvOpShapeHelper shapeHelper(&convOp, rewriter,
+        getDenseElementAttributeFromKrnlValue,
+        loadDenseElementArrayValueAtIndex);
+    auto shapecomputed =
+        shapeHelper.Compute(operandAdaptor, convOp.kernel_shape(),
+            padsAttribute, stridesAttribute, convOp.dilations());
+    assert(succeeded(shapecomputed));
+
     // Scope for krnl ops
     IndexExprScope ieScope(rewriter, loc);
     KrnlBuilder createKrnl(rewriter, loc);
@@ -69,8 +79,8 @@
 
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertToMemRefType(*op->result_type_begin());
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
+    Value alloc = insertAllocAndDeallocSimple(
+        rewriter, op, memRefType, loc, shapeHelper.dimsForOutput(0));
 
     auto resultShape = memRefType.getShape();
     auto inputOperand = operandAdaptor.X();
@@ -79,12 +89,6 @@
     auto biasOperand = operandAdaptor.B();
     bool hasBias = !biasOperand.getType().isa<NoneType>();
 
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, {inputOperand});
-
     // R = Conv(D, K)
     //
     // The input/output shapes will look like this:
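
The Conv hunk above replaces the hand-rolled hasAllConstantDimensions branch with one insertAllocAndDeallocSimple call driven by shapeHelper.dimsForOutput(0). As a sketch of what such a helper boils down to, distilled from the insertAllocAndDeallocForPooling function removed in the Pooling.cpp diff below (assumes MLIR's memref dialect inside a lowering pass; the helper name allocForResult and the ArrayRef<Value> of dynamic dims are illustrative):

// Allocate the result buffer, passing one SSA value per dynamic dimension,
// and park the matching dealloc just before the enclosing block's terminator.
static Value allocForResult(ConversionPatternRewriter &rewriter, Location loc,
    MemRefType type, ArrayRef<Value> dynamicDims, bool insertDealloc) {
  auto alloc = rewriter.create<memref::AllocOp>(loc, type, dynamicDims);
  if (insertDealloc) {
    Block *parentBlock = alloc.getOperation()->getBlock();
    auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
    dealloc.getOperation()->moveBefore(&parentBlock->back());
  }
  return alloc;
}

The appeal of the IndexExpr route is that the same symbolic output dims serve both shape inference (folding to constants when possible) and this allocation (materializing index computations only for genuinely dynamic dims).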
108 changes: 35 additions & 73 deletions src/Conversion/ONNXToKrnl/NN/Pooling.cpp
@@ -13,6 +13,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
+#include "src/Dialect/ONNX/ONNXShapeHelper.hpp"
 
 using namespace mlir;
 
@@ -74,6 +75,21 @@ std::vector<int64_t> getDilations<ONNXMaxPoolSingleOutOp>(
   return dilations;
 }
 
+//===----------------------------------------------------------------------===//
+// Get dilation attribute.
+//
+template <typename PoolOp>
+llvm::Optional<ArrayAttr> getDilationAttr(PoolOp poolOp) {
+  return llvm::None;
+}
+
+// MaxPool has dilations attribute.
+template <>
+llvm::Optional<ArrayAttr> getDilationAttr<ONNXMaxPoolSingleOutOp>(
+    ONNXMaxPoolSingleOutOp poolOp) {
+  return poolOp.dilations();
+}
+
 //===----------------------------------------------------------------------===//
 // Get count_include_pad values
 //
@@ -127,75 +143,17 @@ void postProcessPoolingWindow<ONNXAveragePoolOp>(
   rewriter.create<KrnlStoreOp>(loc, average, alloc, resultIndices);
 }
 
-//===----------------------------------------------------------------------===//
-// Helper function to insert alloc and dealloc ops for memref of dynamic shape.
-//
-Value insertAllocAndDeallocForPooling(ConversionPatternRewriter &rewriter,
-    Location loc, bool insertDealloc, MemRefType memRefType, Value inputOperand,
-    ArrayRef<int64_t> kernelShape, ArrayRef<int64_t> pads,
-    ArrayRef<int64_t> strides, ArrayRef<int64_t> dilations, bool ceilMode) {
-  memref::AllocOp alloc;
-
-  // Shape and rank information related to result and kernel.
-  auto resultShape = memRefType.getShape();
-  auto resultRank = resultShape.size();
-  auto kernelRank = kernelShape.size();
-  int kernelOffset = resultRank - kernelRank;
-
-  // Compute dimensions of the result of this operation.
-  SmallVector<Value, 2> allocOperands;
-  for (int i = 0; i < kernelOffset; ++i) {
-    if (resultShape[i] < 0) {
-      auto dim = rewriter.create<memref::DimOp>(loc, inputOperand, i);
-      allocOperands.emplace_back(dim);
-    }
-  }
-
-  // Obtain an affine map to compute the output dimension.
-  AffineMap dimMap = getConvDimMap(rewriter, ceilMode);
-  for (int i = kernelOffset; i < (int)resultShape.size(); ++i) {
-    if (resultShape[i] < 0) {
-      int spatialIndex = i - kernelOffset;
-      // Prepare arguments for the affine map.
-      SmallVector<Value, 4> dimArgs;
-      dimArgs.emplace_back(
-          rewriter.create<memref::DimOp>(loc, inputOperand, i));
-      dimArgs.emplace_back(emitConstantOp(
-          rewriter, loc, rewriter.getIndexType(), kernelShape[spatialIndex]));
-      dimArgs.emplace_back(
-          emitConstantOp(rewriter, loc, rewriter.getIndexType(),
-              (pads[spatialIndex] + pads[spatialIndex + kernelRank])));
-      dimArgs.emplace_back(emitConstantOp(
-          rewriter, loc, rewriter.getIndexType(), strides[spatialIndex]));
-      dimArgs.emplace_back(
-          emitConstantOp(rewriter, loc, rewriter.getIndexType(),
-              dilations.empty() ? 1 : dilations[spatialIndex]));
-
-      // Apply the affine map.
-      Value dimVal = rewriter.create<AffineApplyOp>(loc, dimMap, dimArgs);
-      allocOperands.emplace_back(dimVal);
-    }
-  }
-  alloc = rewriter.create<memref::AllocOp>(loc, memRefType, allocOperands);
-  if (insertDealloc) {
-    auto *parentBlock = alloc.getOperation()->getBlock();
-    auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-    dealloc.getOperation()->moveBefore(&parentBlock->back());
-  }
-  return alloc;
-}
-
 //===----------------------------------------------------------------------===//
 // Template function that does pooling.
 //
-template <typename PoolOp>
+template <typename PoolOp, typename PoolOpAdaptor>
 struct ONNXPoolOpLowering : public ConversionPattern {
   ONNXPoolOpLowering(MLIRContext *ctx)
       : ConversionPattern(PoolOp::getOperationName(), 1, ctx) {}
 
   LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
-    ONNXMaxPoolSingleOutOpAdaptor operandAdaptor(operands);
+    PoolOpAdaptor operandAdaptor(operands);
     auto loc = op->getLoc();
 
     PoolOp poolOp = llvm::dyn_cast<PoolOp>(op);
@@ -225,8 +183,17 @@ struct ONNXPoolOpLowering : public ConversionPattern {
     std::vector<int64_t> dilations = getDilations<PoolOp>(poolOp);
     bool isDilated = !dilations.empty();
 
+    // Get shape.
+    ONNXPoolOpShapeHelper<PoolOp, PoolOpAdaptor> shapeHelper(&poolOp, rewriter,
+        getDenseElementAttributeFromKrnlValue,
+        loadDenseElementArrayValueAtIndex);
+    auto shapecomputed = shapeHelper.Compute(operandAdaptor,
+        poolOp.kernel_shape(), padsAttribute, stridesAttribute,
+        getDilationAttr<PoolOp>(poolOp), ceilMode);
+    assert(succeeded(shapecomputed));
+
     // Type information about the input and result of this operation.
-    auto inputOperand = operandAdaptor.X();
+    auto inputOperand = (Value)operandAdaptor.X();
     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
     auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto outputShape = memRefType.getShape();
@@ -240,16 +207,8 @@ struct ONNXPoolOpLowering : public ConversionPattern {
     KrnlBuilder createKrnl(rewriter, loc);
 
     // Insert an allocation and deallocation for the output of this operation.
-    Value alloc;
-    bool insertDealloc = checkInsertDealloc(op);
-
-    if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
-    else {
-      alloc = insertAllocAndDeallocForPooling(rewriter, loc, insertDealloc,
-          memRefType, inputOperand, kernelShape, pads, strides, dilations,
-          ceilMode);
-    }
+    Value alloc = insertAllocAndDeallocSimple(
+        rewriter, op, memRefType, loc, shapeHelper.dimsForOutput(0));
 
     // input = Pool(output)
     //
@@ -505,6 +464,9 @@
 
 void populateLoweringONNXPoolingOpPattern(
     RewritePatternSet &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXPoolOpLowering<ONNXMaxPoolSingleOutOp>>(ctx);
-  patterns.insert<ONNXPoolOpLowering<ONNXAveragePoolOp>>(ctx);
+  patterns.insert<ONNXPoolOpLowering<ONNXMaxPoolSingleOutOp,
+      ONNXMaxPoolSingleOutOpAdaptor>>(ctx);
+  patterns
+      .insert<ONNXPoolOpLowering<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor>>(
+          ctx);
 }
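
Two C++ idioms carry the rewritten Pooling.cpp: the lowering pattern is now parameterized over both the op and its adaptor type, and per-op attribute differences are isolated in a full function-template specialization (getDilationAttr above). Here is a self-contained illustration of the specialization idiom in plain C++17, with std::optional standing in for llvm::Optional and hypothetical stand-in op types:

#include <optional>
#include <vector>

struct AveragePool {};                          // stand-in op: no dilations
struct MaxPool { std::vector<int> dilations; }; // stand-in op: has dilations

// Primary template: by default a pool op carries no dilations attribute.
template <typename PoolOp>
std::optional<std::vector<int>> getDilationAttr(const PoolOp &) {
  return std::nullopt;
}

// Full specialization: only MaxPool overrides the default.
template <>
std::optional<std::vector<int>> getDilationAttr<MaxPool>(const MaxPool &op) {
  return op.dilations;
}

The shared matchAndRewrite body then never branches on the op kind at runtime; the right accessor is selected at template-instantiation time.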
52 changes: 23 additions & 29 deletions src/Dialect/ONNX/ONNXOps.cpp
@@ -1714,17 +1714,14 @@ LogicalResult ONNXConvOp::inferShapes(
   auto stridesOpt = strides();
   auto padsOpt = pads();
 
-  // First two output dimensions consist of the number of batches and the
-  // number of kernels being applied.
+  // Infer shape for the output.
+  ONNXConvOpAdaptor operandAdaptor(*this);
+  ONNXConvOpShapeHelper shapeHelper(this);
+  if (failed(shapeHelper.Compute(
+          operandAdaptor, kernelShape, padsOpt, stridesOpt, dilationsOpt)))
+    return emitError("Failed to scan Conv parameters successfully");
   SmallVector<int64_t, 4> outputDims;
-  // Insert batch size.
-  outputDims.emplace_back(xShape[0]);
-  // Insert number of filters being applied (number of output channels).
-  outputDims.emplace_back(weightShape[0]);
-  // Compute and insert spatial dims.
-  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
-      stridesOpt, dilationsOpt);
-
+  IndexExpr::getShape(shapeHelper.dimsForOutput(0), outputDims);
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
   return success();
 }
@@ -1994,8 +1991,6 @@ LogicalResult ONNXAveragePoolOp::inferShapes(
   if (!X().getType().isa<RankedTensorType>())
     return emitError("Input tensor not ranked");
 
-  auto builder = mlir::Builder(getContext());
-
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
   auto xShape = xTy.getShape();
@@ -2021,14 +2016,15 @@
     return res;
   auto padsOpt = pads();
 
+  // Infer shape for the output.
+  ONNXAveragePoolOpAdaptor operandAdaptor(*this);
+  ONNXPoolOpShapeHelper<ONNXAveragePoolOp, ONNXAveragePoolOpAdaptor>
+      shapeHelper(this);
+  if (failed(shapeHelper.Compute(operandAdaptor, kernelShape, padsOpt,
+          stridesOpt, llvm::None, ceilMode)))
+    return emitError("Failed to scan AveragePool parameters successfully");
   SmallVector<int64_t, 4> outputDims;
-  // Insert batch size.
-  outputDims.emplace_back(xShape[0]);
-  outputDims.emplace_back(xShape[1]);
-  // Compute and insert spatial dims.
-  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
-      stridesOpt, llvm::None, ceilMode);
-
+  IndexExpr::getShape(shapeHelper.dimsForOutput(0), outputDims);
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
   return success();
 }
@@ -2048,11 +2044,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes(
   if (!X().getType().isa<RankedTensorType>())
     return emitError("Input tensor not ranked");
 
-  auto builder = mlir::Builder(getContext());
-
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
-  auto xShape = xTy.getShape();
 
   // Kernel shape.
   auto kernelShape = kernel_shape();
@@ -2075,14 +2068,15 @@
   // Ceil mode.
   auto ceilMode = ceil_mode();
 
+  // Infer shape for the output.
+  ONNXMaxPoolSingleOutOpAdaptor operandAdaptor(*this);
+  ONNXPoolOpShapeHelper<ONNXMaxPoolSingleOutOp, ONNXMaxPoolSingleOutOpAdaptor>
+      shapeHelper(this);
+  if (failed(shapeHelper.Compute(operandAdaptor, kernelShape, padsOpt,
+          stridesOpt, dilationsOpt, ceilMode)))
+    return emitError("Failed to scan MaxPool parameters successfully");
   SmallVector<int64_t, 4> outputDims;
-  // Insert batch size.
-  outputDims.emplace_back(xShape[0]);
-  outputDims.emplace_back(xShape[1]);
-  // Compute and insert spatial dims.
-  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
-      stridesOpt, dilationsOpt, ceilMode);
-
+  IndexExpr::getShape(shapeHelper.dimsForOutput(0), outputDims);
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
   return success();
 }
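
The net effect of the ONNXOps.cpp changes is that inferShapes now tolerates unknown dimensions (hence the "Test ResNet50 with unknown dimensions" commit): IndexExpr dims that cannot fold to literals come back as dynamic sizes instead of derailing inference. A plain-C++ sketch of that propagation for a 2-D convolution, with -1 as the dynamic marker; the helper name, the single symmetric pad, and the NCHW/MCKK layouts are illustrative assumptions:

#include <cstdint>
#include <vector>

// x: {N, C, H, W} input dims; w: {M, C, KH, KW} weight dims; -1 = unknown.
std::vector<int64_t> inferConvShape(const std::vector<int64_t> &x,
    const std::vector<int64_t> &w, int64_t pad, int64_t stride,
    int64_t dilation) {
  auto spatial = [&](int64_t in, int64_t k) -> int64_t {
    if (in < 0)
      return -1; // dynamic in -> dynamic out
    return (in + 2 * pad - dilation * (k - 1) - 1) / stride + 1;
  };
  return {x[0], w[0], spatial(x[2], w[2]), spatial(x[3], w[3])};
}

For ResNet50's first convolution with an unknown batch, x = {-1, 3, 224, 224}, w = {64, 3, 7, 7}, pad = 3, stride = 2, dilation = 1 yields {-1, 64, 112, 112}: the batch stays dynamic while the spatial dims still fold to constants.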