Skip to content

Commit

Permalink
[mlir] specify the values when notifying about op replacement
Browse files Browse the repository at this point in the history
It is useful for PatternRewriter listeners to know the values that are
replacing the op in addition to only the fact of the op being replaced
for being able to keep track of changes or for debugging.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134748
  • Loading branch information
ftynse committed Sep 27, 2022
1 parent 401481d commit e8aaf75
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,9 @@ class RewriterBase : public OpBuilder, public OpBuilder::Listener {
/// they would like to be notified about certain types of mutations.

/// Notify the rewriter that the specified operation is about to be replaced
/// with another set of operations. This is called before the uses of the
/// operation have been changed.
virtual void notifyRootReplaced(Operation *op) {}
/// with the set of values potentially produced by new operations. This is
/// called before the uses of the operation have been changed.
virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {}

/// This is called on an operation that a rewrite is removing, right before
/// the operation is deleted. At this point, the operation has zero uses.
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ void RewriterBase::replaceOpWithIf(
"incorrect number of values to replace operation");

// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
notifyRootReplaced(op, newValues);

// Replace each use of the results when the functor is true.
bool replacedAllUses = true;
Expand Down Expand Up @@ -244,7 +244,7 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
/// the operation.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
// Notify the rewriter subclass that we're about to replace this root.
notifyRootReplaced(op);
notifyRootReplaced(op, newValues);

assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
// When the root of a pattern is about to be replaced, it can trigger
// simplifications to its users - make sure to add them to the worklist
// before the root is changed.
void notifyRootReplaced(Operation *op) override;
void notifyRootReplaced(Operation *op, ValueRange replacement) override;

/// PatternRewriter hook for erasing a dead operation.
void eraseOp(Operation *op) override;
Expand Down Expand Up @@ -348,7 +348,8 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
});
}

void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) {
void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
ValueRange replacement) {
LLVM_DEBUG({
logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
<< ")\n";
Expand Down Expand Up @@ -437,7 +438,7 @@ class OpPatternRewriteDriver : public PatternRewriter {

// When a root is going to be replaced, its removal will be notified as well.
// So there is nothing to do here.
void notifyRootReplaced(Operation *op) override {}
void notifyRootReplaced(Operation *op, ValueRange replacement) override {}

private:
/// The low-level pattern applicator.
Expand Down

0 comments on commit e8aaf75

Please sign in to comment.