Skip to content

Commit

Permalink
Revert "[mlir][Transforms] Add 1:N matchAndRewrite overload (llvm#1…
Browse files Browse the repository at this point in the history
…16470)"

This reverts commit 9df63b2.
  • Loading branch information
nithinsubbiah committed Dec 10, 2024
1 parent e9c68c6 commit 967e3a9
Show file tree
Hide file tree
Showing 12 changed files with 455 additions and 494 deletions.
35 changes: 4 additions & 31 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
Expand All @@ -155,29 +153,17 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
return matchAndRewrite(cast<SourceOp>(op),
OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
Expand All @@ -189,12 +175,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -203,13 +183,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConvertToLLVMPattern::match;
Expand Down
69 changes: 0 additions & 69 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,15 +538,8 @@ class ConversionPattern : public RewritePattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports only 1:1 replacements. The 1:N overload is called
/// by the driver. By default, it calls this 1:1 overload or reports a fatal
/// error if 1:N replacements were found.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -556,14 +549,6 @@ class ConversionPattern : public RewritePattern {
return success();
}

/// Hook for derived classes to implement combined matching and rewriting.
/// This overload supports 1:N replacements.
virtual LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const final;
Expand All @@ -590,15 +575,6 @@ class ConversionPattern : public RewritePattern {
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}

/// Given an array of value ranges, which are the inputs to a 1:N adaptor,
/// try to extract the single value of each range to construct a the inputs
/// for a 1:1 adaptor.
///
/// This function produces a fatal error if at least one range has 0 or
/// more than 1 value: "pattern 'name' does not support 1:N conversion"
SmallVector<Value>
getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;

protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
Expand All @@ -614,8 +590,6 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OneToNOpAdaptor =
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
Expand All @@ -634,24 +608,12 @@ class OpConversionPattern : public ConversionPattern {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
Expand All @@ -662,12 +624,6 @@ class OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -676,13 +632,6 @@ class OpConversionPattern : public ConversionPattern {
rewrite(op, adaptor, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
SmallVector<Value> oneToOneOperands =
getOneToOneAdaptorOperands(adaptor.getOperands());
return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand All @@ -708,31 +657,18 @@ class OpInterfaceConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}

/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -741,11 +677,6 @@ class OpInterfaceConversionPattern : public ConversionPattern {
rewrite(op, operands, rewriter);
return success();
}
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const {
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
}

private:
using ConversionPattern::matchAndRewrite;
Expand Down
Loading

0 comments on commit 967e3a9

Please sign in to comment.