Skip to content

Commit

Permalink
Revert "[mlir][Func] Delete DecomposeCallGraphTypes.cpp (llvm#117424)"
Browse files Browse the repository at this point in the history
This reverts commit 7267c85.

Signed-off-by: nithinsubbiah <[email protected]>
  • Loading branch information
nithinsubbiah committed Dec 5, 2024
1 parent f82c4b1 commit 6038573
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- DecomposeCallGraphTypes.h - CG type decompositions -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Conversion patterns for decomposing types along call graph edges. That is,
// decomposing types for calls, returns, and function args.
//
// TODO: Make this handle dialect-defined functions, calls, and returns.
// Currently, the generic interfaces aren't sophisticated enough for the
// types of mutations that we are doing here.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
#define MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H

#include "mlir/Transforms/DialectConversion.h"
#include <optional>

namespace mlir {

/// Populates the patterns needed to drive the conversion process for
/// decomposing call graph types with the given `TypeConverter`.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
const TypeConverter &typeConverter,
RewritePatternSet &patterns);

} // namespace mlir

#endif // MLIR_DIALECT_FUNC_TRANSFORMS_DECOMPOSECALLGRAPHTYPES_H
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRFuncTransforms
DecomposeCallGraphTypes.cpp
DuplicateFunctionElimination.cpp
FuncConversions.cpp
OneToNFuncConversions.cpp
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
using OpConversionPattern<ReturnOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<ReturnOp>(op,
flattenValues(adaptor.getOperands()));
// For a return, all operands go to the results of the parent, so
// rewrite them all.
rewriter.modifyOpInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return success();
}
};
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2565,7 +2565,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
"materialization callback produced value of incorrect type");
#endif // NDEBUG
rewriter.replaceOp(op, newMaterialization);
return success();
}
Expand Down
7 changes: 2 additions & 5 deletions mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "TestDialect.h"
#include "TestOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
Expand Down Expand Up @@ -142,10 +142,7 @@ struct TestDecomposeCallGraphTypes
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildDecomposeTuple);

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
populateReturnOpTypeConversionPattern(patterns, typeConverter);
populateCallOpTypeConversionPattern(patterns, typeConverter);
populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);

if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
Expand Down

0 comments on commit 6038573

Please sign in to comment.