Skip to content

Commit

Permalink
Create ops with static shapes in PadTensorOpConversion if they're sta…
Browse files Browse the repository at this point in the history
…tic (#5516)

Currently, the mhlo.pad will be lowered to linalg.pad_tensor and then
lowered to `linalg.init_tensor + linalg.fill + subtensor_insert`. The
init_tensor op will produce a dynamic shape even if the shape is static.
This leads a `tensor.cast` op added and relies on further patterns to
fix them. This is not needed to static shape and will hit some issues
related to #5385.
  • Loading branch information
hanhanW authored Apr 19, 2021
1 parent 840f4c3 commit 131f2e2
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct PadTensorOpConversion : public OpConversionPattern<linalg::PadTensorOp> {

// TODO(ravishankarm): Use shape inference interface to get this.
SmallVector<OpFoldResult> sourceShape;
SmallVector<Value> outputShape;
SmallVector<OpFoldResult> outputShape;
for (int64_t dim : llvm::seq<int64_t>(0, rank)) {
SmallVector<Value> mapValues;
Value sourceDim = rewriter.createOrFold<memref::DimOp>(loc, source, dim);
Expand All @@ -106,21 +106,21 @@ struct PadTensorOpConversion : public OpConversionPattern<linalg::PadTensorOp> {
};
expr = addValueOrAttr(expr, lowPad[dim]);
expr = addValueOrAttr(expr, highPad[dim]);
outputShape.push_back(linalg::applyMapToValues(
rewriter, loc, AffineMap::get(1, numSymbols, expr), mapValues)[0]);
Value v = linalg::applyMapToValues(
rewriter, loc, AffineMap::get(1, numSymbols, expr), mapValues)[0];
if (auto cst = v.getDefiningOp<ConstantOp>()) {
outputShape.push_back(cst.value());
} else {
outputShape.push_back(v);
}
}
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, outputShape, sourceType.getElementType());
Value fill =
rewriter.create<linalg::FillOp>(loc, initTensor, yieldVal).getResult(0);
SmallVector<OpFoldResult> strides(rank, rewriter.getI64IntegerAttr(1));
Value replacement = rewriter.create<SubTensorInsertOp>(
loc, source, fill, lowPad, sourceShape, strides);
if (padTensorOp.getResultType() != replacement.getType()) {
replacement = rewriter.create<tensor::CastOp>(
loc, padTensorOp.getResultType(), replacement);
}
rewriter.replaceOp(padTensorOp, replacement);
rewriter.replaceOpWithNewOp<SubTensorInsertOp>(
padTensorOp, source, fill, lowPad, sourceShape, strides);
return success();
}
};
Expand Down

0 comments on commit 131f2e2

Please sign in to comment.