diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h deleted file mode 100644 index 1be406bf3adf92..00000000000000 --- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h +++ /dev/null @@ -1,34 +0,0 @@ -//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Conversion patterns for decomposing types along call graph edges. That is, -// decomposing types for calls, returns, and function args. -// -// TODO: Make this handle dialect-defined functions, calls, and returns. -// Currently, the generic interfaces aren't sophisticated enough for the -// types of mutations that we are doing here. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H -#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H - -#include "mlir/Transforms/DialectConversion.h" -#include - -namespace mlir { - -/// Populates the patterns needed to drive the conversion process for -/// decomposing call graph types with the given `TypeConverter`. -void populateDecomposeCallGraphTypesPatterns(MLIRContext *context, - const TypeConverter &typeConverter, - RewritePatternSet &patterns); - -} // namespace mlir - -#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt index f8fb1f436a95b1..6384d25ee70273 100644 --- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(MLIRFuncTransforms - DecomposeCallGraphTypes.cpp DuplicateFunctionElimination.cpp FuncConversions.cpp OneToNFuncConversions.cpp diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp deleted file mode 100644 index 03be00328bda33..00000000000000 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ /dev/null @@ -1,136 +0,0 @@ -//===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" - -using namespace mlir; -using namespace mlir::func; - -//===----------------------------------------------------------------------===// -// DecomposeCallGraphTypesForFuncArgs -//===----------------------------------------------------------------------===// - -namespace { -/// Expand function arguments according to the provided TypeConverter. -struct DecomposeCallGraphTypesForFuncArgs - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - auto functionType = op.getFunctionType(); - - // Convert function arguments using the provided TypeConverter. - TypeConverter::SignatureConversion conversion(functionType.getNumInputs()); - for (const auto &argType : llvm::enumerate(functionType.getInputs())) { - SmallVector decomposedTypes; - if (failed(typeConverter->convertType(argType.value(), decomposedTypes))) - return failure(); - if (!decomposedTypes.empty()) - conversion.addInputs(argType.index(), decomposedTypes); - } - - // If the SignatureConversion doesn't apply, bail out. - if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(), - &conversion))) - return failure(); - - // Update the signature of the function. - SmallVector newResultTypes; - if (failed(typeConverter->convertTypes(functionType.getResults(), - newResultTypes))) - return failure(); - rewriter.modifyOpInPlace(op, [&] { - op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), - newResultTypes)); - }); - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// DecomposeCallGraphTypesForReturnOp -//===----------------------------------------------------------------------===// - -namespace { -/// Expand return operands according to the provided TypeConverter. -struct DecomposeCallGraphTypesForReturnOp - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - SmallVector newOperands; - for (ValueRange operand : adaptor.getOperands()) - llvm::append_range(newOperands, operand); - rewriter.replaceOpWithNewOp(op, newOperands); - return success(); - } -}; -} // namespace - -//===----------------------------------------------------------------------===// -// DecomposeCallGraphTypesForCallOp -//===----------------------------------------------------------------------===// - -namespace { -/// Expand call op operands and results according to the provided TypeConverter. -struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - - // Create the operands list of the new `CallOp`. - SmallVector newOperands; - for (ValueRange operand : adaptor.getOperands()) - llvm::append_range(newOperands, operand); - - // Create the new result types for the new `CallOp` and track the number of - // replacement types for each original op result. - SmallVector newResultTypes; - SmallVector expandedResultSizes; - for (Type resultType : op.getResultTypes()) { - unsigned oldSize = newResultTypes.size(); - if (failed(typeConverter->convertType(resultType, newResultTypes))) - return failure(); - expandedResultSizes.push_back(newResultTypes.size() - oldSize); - } - - CallOp newCallOp = rewriter.create(op.getLoc(), op.getCalleeAttr(), - newResultTypes, newOperands); - - // Build a replacement value for each result to replace its uses. - SmallVector replacedValues; - replacedValues.reserve(op.getNumResults()); - unsigned startIdx = 0; - for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) { - ValueRange repl = - newCallOp.getResults().slice(startIdx, expandedResultSizes[i]); - replacedValues.push_back(repl); - startIdx += expandedResultSizes[i]; - } - rewriter.replaceOpWithMultiple(op, replacedValues); - return success(); - } -}; -} // namespace - -void mlir::populateDecomposeCallGraphTypesPatterns( - MLIRContext *context, const TypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns - .add(typeConverter, context); -} diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp index 9e7759bef6d8fd..a3638c8766a5c6 100644 --- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp @@ -124,12 +124,10 @@ class ReturnOpTypeConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - // For a return, all operands go to the results of the parent, so - // rewrite them all. - rewriter.modifyOpInPlace(op, - [&] { op->setOperands(adaptor.getOperands()); }); + rewriter.replaceOpWithNewOp(op, + flattenValues(adaptor.getOperands())); return success(); } }; diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index de511c58ae6ee0..09c5b4b2a0ad50 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -9,7 +9,7 @@ #include "TestDialect.h" #include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -142,7 +142,10 @@ struct TestDecomposeCallGraphTypes typeConverter.addArgumentMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildDecomposeTuple); - populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + populateReturnOpTypeConversionPattern(patterns, typeConverter); + populateCallOpTypeConversionPattern(patterns, typeConverter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) return signalPassFailure();