Fuse tensor.pad with producers. #9194

Closed
@@ -62,6 +62,11 @@ static llvm::cl::opt<bool> clEnableMultiResultDispatches(
"Enable dispatch region formation to enable multi-result dispatches"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableFusePaddingIntoConsumerOps(
"iree-flow-enable-fuse-padding-into-consumer-ops",
llvm::cl::desc("Enable fusing linalg pad_tensor ops into consumer ops"),
llvm::cl::init(false));

static const char kRootOpAttr[] = "__root_op__";
static const char kFusionGroupsAttr[] = "__fused_op__";

@@ -151,8 +156,7 @@ static bool isRootOp(Operation *op) {
}
return !isa<linalg::FillOp>(op);
}
return isa<IREE::LinalgExt::TiledOpInterface>(op) &&
!isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(op);
return isa<IREE::LinalgExt::TiledOpInterface>(op);
}

/// Operations that are cloned into dispatch regions formed with other
@@ -162,7 +166,7 @@ static bool isClonableIntoDispatchOp(Operation *op) {
// trivially clonable too, but they cause problems
// with bufferization. Make them clonable when fixed.
if (isa<arith::IndexCastOp, linalg::InitTensorOp, tensor::CastOp,
tensor::ExtractOp, tensor::ExtractSliceOp, tensor::PadOp>(op)) {
tensor::ExtractOp, tensor::ExtractSliceOp>(op)) {
return true;
}
if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
@@ -420,14 +424,11 @@ buildOperandLessFlowDispatchWorkgroupOp(PatternRewriter &rewriter, Location loc,
}

/// Returns the list of operations that are to be cloned into the dispatch
/// based on the root operation.
/// with `rootOp` being the root, and belonging to `groupNum`.
static SmallVector<Operation *> getOperationsToMoveIntoDispatch(
Operation *rootOp) {
Operation *rootOp, unsigned groupNum) {
SmallVector<Operation *> dispatchOps;
dispatchOps.push_back(rootOp);
if (!hasRootOpAttribute(rootOp)) return dispatchOps;

int64_t groupNum = getRootNumber(rootOp);
std::deque<Operation *> worklist;
worklist.push_back(rootOp);
llvm::SmallDenseSet<Operation *, 2> movedOps;
@@ -448,6 +449,15 @@ static SmallVector<Operation *> getOperationsToMoveIntoDispatch(
return dispatchOps;
}

/// Returns the list of operations that are to be cloned into the dispatch
/// based on the root operation.
static SmallVector<Operation *> getOperationsToMoveIntoDispatch(
Operation *rootOp) {
if (!hasRootOpAttribute(rootOp)) return {rootOp};

return getOperationsToMoveIntoDispatch(rootOp, getRootNumber(rootOp));
}

//===---------------------------------------------------------------------===//
// Methods to legalize a dispatch region op, i.e. make it isolated from above.
//===---------------------------------------------------------------------===//
@@ -835,6 +845,89 @@ static LogicalResult legalizeDispatchWorkgroupOperands(
return success();
}

//===----------------------------------------------------------------------===//
// Pad op handling.
//===----------------------------------------------------------------------===//

/// Converts a `tensor.pad` operation
///
/// ```mlir
/// %0 = tensor.pad %source low[%low0, %low1] high[%high0, %high1]
/// ```
///
/// into
///
/// ```mlir
/// %0 = linalg.fill ins(%pad_value)
/// %1 = tensor.insert_slice %source into %0[%low0, %low1]
/// ```
///
/// and returns the `tensor.insert_slice`.
static FailureOr<tensor::InsertSliceOp> convertPadOpToFillAndInsertSlice(
OpBuilder &builder, tensor::PadOp padTensorOp) {
// Check that the region contains just a yield operation that returns a
// scalar which is not one of the block arguments of the pad operation.
Region &region = padTensorOp.region();
Block &block = region.front();
if (!llvm::hasSingleElement(block)) return failure();
auto yieldOp = cast<tensor::YieldOp>(block.getTerminator());
Value yieldVal = yieldOp.value();
if (llvm::any_of(block.getArguments(),
[&](Value v) { return v == yieldVal; })) {
return failure();
}

OpBuilder::InsertionGuard g(builder);
Location loc = padTensorOp.getLoc();
auto lowPad = padTensorOp.getMixedLowPad();
auto highPad = padTensorOp.getMixedHighPad();
Value source = padTensorOp.source();
RankedTensorType sourceType = padTensorOp.getSourceType();
int64_t rank = sourceType.getRank();

// Use the `ReifyRankedShapedTypeOpInterface` to get the output shape.
ReifiedRankedShapedTypeDims outputShape;
if (failed(cast<ReifyRankedShapedTypeOpInterface>(padTensorOp.getOperation())
.reifyResultShapes(builder, outputShape))) {
return failure();
}
if (outputShape.size() != 1) return failure();
SmallVector<Value> sourceShape = llvm::to_vector(
llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
return builder.create<tensor::DimOp>(loc, source, dim);
}));

Value initTensor = builder.create<linalg::InitTensorOp>(
loc, outputShape[0], sourceType.getElementType());
Value fill =
builder.create<linalg::FillOp>(loc, yieldVal, initTensor).getResult(0);
SmallVector<OpFoldResult> strides(rank, builder.getI64IntegerAttr(1));
return builder.create<tensor::InsertSliceOp>(
loc, source, fill, lowPad, getAsOpFoldResult(sourceShape), strides);
}
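
For orientation, here is a minimal sketch of the rewrite this helper performs, using a made-up static-shape case so the reified output sizes fold to constants (with dynamic shapes the insert sizes come from the `tensor.dim` ops created in the code above; the value names and shapes below are purely illustrative):

```mlir
// Assume %source : tensor<6x6xf32> is given and the pad value is a constant.
%cst = arith.constant 0.0 : f32

// Before: pad by one element on every side.
%pad = tensor.pad %source low[1, 1] high[1, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<6x6xf32> to tensor<8x8xf32>

// After: fill an 8x8 init tensor with the pad value, then insert the source
// at the low-pad offsets with unit strides.
%init = linalg.init_tensor [8, 8] : tensor<8x8xf32>
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8x8xf32>) -> tensor<8x8xf32>
%result = tensor.insert_slice %source into %fill[1, 1] [6, 6] [1, 1]
    : tensor<6x6xf32> into tensor<8x8xf32>
```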

namespace {
/// Wraps the conversion of `tensor.pad` into `linalg.fill` +
/// `tensor.insert_slice` in a pattern.
struct TensorPadOpConversion : public OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
if (!hasComputeUsesOutsideDispatch(padOp)) return failure();
if (padOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
return failure();
}

FailureOr<tensor::InsertSliceOp> insertSliceOp =
convertPadOpToFillAndInsertSlice(rewriter, padOp);
if (failed(insertSliceOp)) return failure();

rewriter.replaceOp(padOp, insertSliceOp->getResult());
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
// Patterns that create the dispatch region.
//===----------------------------------------------------------------------===//
@@ -887,21 +980,91 @@ struct CreateDispatchRegionOp : Base<OpType> {
private:
linalg::LinalgTransformationFilter transformationFilter;
};

/// Pattern to create a dispatch workgroup operation for `tensor.pad`
/// operations. Breaks the `tensor.pad` into `linalg.fill` ->
/// `tensor.insert_slice`. The `linalg.fill` is left outside the dispatch and
/// the `tensor.insert_slice` is used as the root of the dispatch.
struct CreateDispatchOpFromPadRootOp : OpRewritePattern<tensor::PadOp> {
CreateDispatchOpFromPadRootOp(
MLIRContext *context, const linalg::LinalgTransformationFilter &filter,
PatternBenefit benefit = 1)
: OpRewritePattern(context, benefit), transformationFilter(filter) {}

LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
// TODO(ravishankarm): It is getting strange to track when to apply this
// pattern and when not to. Need to revisit this, with dynamic shape cases
// in mind.
if (!hasComputeUsesOutsideDispatch(padOp)) return failure();
if (padOp->template getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
return failure();
}

if (failed(transformationFilter.checkAndNotify(rewriter, padOp))) {
return failure();
}

// Split the pad operation into fill + tensor.insert_slice
FailureOr<tensor::InsertSliceOp> insertSlice =
convertPadOpToFillAndInsertSlice(rewriter, padOp);
if (failed(insertSlice)) return failure();

// Get the workload to use for the dispatch.
FailureOr<SmallVector<Value>> workload =
getWorkloadForRootOp(rewriter, insertSlice.getValue());
if (failed(workload)) {
return failure();
}

SmallVector<Operation *> dispatchOps;
if (!hasRootOpAttribute(padOp)) {
dispatchOps.push_back(insertSlice.getValue());
} else {
dispatchOps = getOperationsToMoveIntoDispatch(insertSlice.getValue(),
getRootNumber(padOp));
}
rewriter.setInsertionPoint(insertSlice.getValue());
rewriter.replaceOp(padOp, insertSlice->getResult());

auto clonedOps = buildOperandLessFlowDispatchWorkgroupOp(
rewriter, insertSlice->getLoc(), workload.getValue(), dispatchOps);
if (failed(clonedOps)) {
return failure();
}
transformationFilter.replaceLinalgTransformationFilter(
rewriter, clonedOps.getValue()[0]);
return success();
}

private:
linalg::LinalgTransformationFilter transformationFilter;
};
} // namespace
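
As a rough illustration of the intended end state for a pad that becomes a dispatch root on its own, consider the hypothetical input below. The expectations are written as a FileCheck-style sketch because the exact `flow.dispatch.workgroups` assembly (workload operands, `flow.dispatch.tensor.load`/`store` plumbing) is elided; this is an illustration only, not the contents of the `fuse_pad_with_consumer.mlir` test added below.

```mlir
// Hypothetical input: a pad whose result escapes the function.
func.func @pad_only(%source: tensor<6x6xf32>) -> tensor<8x8xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = tensor.pad %source low[1, 1] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<6x6xf32> to tensor<8x8xf32>
  return %0 : tensor<8x8xf32>
}

// Expected structure after dispatch region formation (sketch):
// CHECK:     linalg.init_tensor [8, 8]
// CHECK:     linalg.fill
// CHECK:     flow.dispatch.workgroups
// CHECK:       tensor.insert_slice
// CHECK-NOT: tensor.pad
```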

//===----------------------------------------------------------------------===//
// Heuristics for fusing dispatchable ops with root ops using tile + fuse.
//===----------------------------------------------------------------------===//

/// Checks if the producer and consumer LinalgOps can be fused.
static bool areFusableLinalgOps(OpOperand &use) {
return areLinalgOpsFusableUsingTileAndFuse(use);
}

/// Returns true if this is a fusable use.
static bool isFusableWithConsumer(OpOperand &use) {
Operation *producer = use.get().getDefiningOp();
if (!producer) return false;

Operation *consumer = use.getOwner();

// Check for linalg producer -> consumer fusion with tile + fuse.
return areFusableLinalgOps(use);
if (isa<linalg::LinalgOp>(producer) && isa<linalg::LinalgOp>(consumer) &&
areLinalgOpsFusableUsingTileAndFuse(use)) {
return true;
}

// Can fuse with a `tensor.insert_slice` consumer if this is not part of
// a concat chain.
if (!clEnableFusePaddingIntoConsumerOps && isa<tensor::PadOp>(consumer)) {
return true;
}
return false;
}
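
To make the new `tensor.pad` consumer case concrete (it applies on the default path, when `iree-flow-enable-fuse-padding-into-consumer-ops` is off), here is a hypothetical producer-plus-pad chain and the grouping it is intended to get; the op choice, names, and shapes are invented, and the expectations are again only a FileCheck-style sketch.

```mlir
// Hypothetical input: a matmul whose result is padded, e.g. for alignment.
func.func @pad_matmul(%lhs: tensor<6x4xf32>, %rhs: tensor<4x6xf32>,
                      %acc: tensor<6x6xf32>) -> tensor<8x8xf32> {
  %cst = arith.constant 0.0 : f32
  %0 = linalg.matmul ins(%lhs, %rhs : tensor<6x4xf32>, tensor<4x6xf32>)
                     outs(%acc : tensor<6x6xf32>) -> tensor<6x6xf32>
  %1 = tensor.pad %0 low[1, 1] high[1, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<6x6xf32> to tensor<8x8xf32>
  return %1 : tensor<8x8xf32>
}

// Rough expected grouping: the fill stays outside while the producer and the
// insert_slice share one dispatch (sketch):
// CHECK:   linalg.fill
// CHECK:   flow.dispatch.workgroups
// CHECK:     linalg.matmul
// CHECK:     tensor.insert_slice
```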

/// For all uses of an operation, finds the use that dominates all other uses.
@@ -989,13 +1152,20 @@ static unsigned decideFusableLinalgOps(FunctionOpInterface funcOp,
unsigned newGroup = numRootOps++;
setRootAttribute(context, &op, newGroup);

linalg::OpOperandVector outOperands =
TypeSwitch<Operation *, linalg::OpOperandVector>(&op)
.Case<linalg::LinalgOp>([&](auto linalgOp) {
return linalgOp.getOutputTensorOperands();
})
.Default(
[&](Operation *) -> linalg::OpOperandVector { return {}; });
if (clEnableFusePaddingIntoConsumerOps) {
for (auto operand : op.getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp) continue;
if (isa<tensor::PadOp>(definingOp)) {
appendToFusionGroup(definingOp, newGroup);
}
}
}

linalg::OpOperandVector outOperands;
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(&op)) {
outOperands = linalgOp.getOutputTensorOperands();
}
for (OpOperand *operand : outOperands) {
// Currently only fuse with producer ops that are `LinalgOp`s.
auto producer = operand->get().getDefiningOp<linalg::LinalgOp>();
@@ -1170,6 +1340,7 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
filterForComputeOps.setMatchByDefault();
RewritePatternSet computeOpDispatchPatterns(context);
computeOpDispatchPatterns.insert<
CreateDispatchOpFromPadRootOp,
CreateDispatchRegionOp<linalg::LinalgOp, OpInterfaceRewritePattern>,
CreateDispatchRegionOp<IREE::LinalgExt::TiledOpInterface,
OpInterfaceRewritePattern>,
@@ -1187,10 +1358,12 @@ void DispatchLinalgOnTensorsPass::runOnOperation() {
llvm::dbgs() << "\n\n";
});

/// Convert remaining ops to Flow ops.
/// Convert remaining ops to Flow ops. Also handle pad ops that aren't
/// fused with other dispatches.
{
RewritePatternSet convertToFlowPatterns(context);
populateTensorToFlowConversionPatterns(context, convertToFlowPatterns);
convertToFlowPatterns.insert<TensorPadOpConversion>(context);
memref::populateResolveRankedShapeTypeResultDimsPatterns(
convertToFlowPatterns);
IREE::Flow::TensorReshapeOp::getCanonicalizationPatterns(
@@ -64,11 +64,6 @@ static llvm::cl::opt<bool> clEnablePaddingLinalgOps(
"flow-padding-size"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clEnableFusePaddingIntoConsumerOps(
"iree-flow-enable-fuse-padding-into-consumer-ops",
llvm::cl::desc("Enable fusing linalg pad_tensor ops into consumer ops"),
llvm::cl::init(false));

static llvm::cl::opt<int> clLinalgOpsPaddingSize(
"iree-flow-linalg-ops-padding-size",
llvm::cl::desc("Enable padding linalg ops to an integer multiple of "
@@ -189,10 +184,6 @@ void buildFlowTransformPassPipeline(OpPassManager &passManager,
buildGlobalOptimizationPassPipeline(passManager, transformOptions);

FunctionLikeNest(passManager)
// Pad tensors.
.addPredicatedPass((!clEnableFusePaddingIntoConsumerOps),
IREE::Flow::createPadTensorToSubTensorInsertPass)

// Preprocess the input to a form more amenable for fusion
// - Convert all elementwise ops to Linalg
// - Remove unit-extent dimensions.
@@ -27,6 +27,7 @@ iree_lit_test_suite(
"dispatch_linalg_on_tensors_fusion.mlir",
"expand_tensor_shapes.mlir",
"export_benchmark_funcs.mlir",
"fuse_pad_with_consumer.mlir",
"infer_numeric_narrowing.mlir",
"initialize_empty_tensor.mlir",
"inject_dispatch_tracing.mlir",
@@ -24,6 +24,7 @@ iree_lit_test_suite(
"dispatch_linalg_on_tensors_fusion.mlir"
"expand_tensor_shapes.mlir"
"export_benchmark_funcs.mlir"
"fuse_pad_with_consumer.mlir"
"infer_numeric_narrowing.mlir"
"initialize_empty_tensor.mlir"
"inject_dispatch_tracing.mlir"