Skip to content

Commit

Permalink
Add RecomposeComplexOps declaration + fix typos in pass name (#1950)
Browse files Browse the repository at this point in the history
The `RecomposeComplexOps` pass currently does not have a TableGen
declaration and it is using the base class of `DecomposeComplexOps`,
which causes `--mlir-print-ir-after-all` to create wrong pass
labels. This commit fixes that as well as some minor typos in the name
of the pass.
  • Loading branch information
ramiro050 authored Mar 28, 2023
1 parent d803ab4 commit 0103c55
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
2 changes: 1 addition & 1 deletion include/torch-mlir/Dialect/Torch/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createRefinePublicReturnPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDecomposeComplexOpsPass(ArrayRef<std::string> legalOps);

std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOps();
std::unique_ptr<OperationPass<func::FuncOp>> createRecomposeComplexOpsPass();

std::unique_ptr<OperationPass<ModuleOp>>
createReifyShapeCalculationsPass(StringRef extraLibrary);
Expand Down
23 changes: 23 additions & 0 deletions include/torch-mlir/Dialect/Torch/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,29 @@ def DecomposeComplexOps : Pass<"torch-decompose-complex-ops", "func::FuncOp"> {
}];
}

def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> {
let summary = "Recompose torch operations that have been decomposed by TorchScript";
let constructor = "mlir::torch::Torch::createRecomposeComplexOpsPass()";
let description = [{
There are certain ops that TorchScript will split into multiple ops that
prevent optimizations in Torch-MLIR from taking place. In this pass such
sequences of ops are identified and combined into a higher level op,
preserving the original behavior, while allowing new optimizations to happen.

An example is the handling of the indexing operation in PyTorch. The following

```
input_tensor[1:2, :] = 7
```

will get split into a series of `slice` ops to get the sub-tensor, then an
in-place copy to overwrite the sub-tensor with the value 7. This type of
pattern prevents the `MaximizeValueSemantics` pass from succeeding. So,
using `RecomposeComplexOps`, the series of slices + copy is identified
and turned into a single `index_put` operation.
}];
}

def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> {
let summary = "Reify shape calculations.";
let constructor = [{
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Torch/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
// Clean up again to avoid needing to to back around the fixed-point
// iteration.
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOps());
pm.addNestedPass<func::FuncOp>(createRecomposeComplexOpsPass());
// Reduce variants of ops to a smaller set of primitives.
pm.addNestedPass<func::FuncOp>(
createReduceOpVariantsPass(options.extraLibrary));
Expand Down
9 changes: 4 additions & 5 deletions lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ class RecomposeSliceCopy_ : public OpRewritePattern<AtenCopy_Op> {
} // namespace

namespace {
class RecomposeComplexOps
: public DecomposeComplexOpsBase<RecomposeComplexOps> {
class RecomposeComplexOpsPass
: public RecomposeComplexOpsBase<RecomposeComplexOpsPass> {
public:
RecomposeComplexOps() = default;
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
Expand All @@ -98,6 +97,6 @@ class RecomposeComplexOps
} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::torch::Torch::createRecomposeComplexOps() {
return std::make_unique<RecomposeComplexOps>();
mlir::torch::Torch::createRecomposeComplexOpsPass() {
return std::make_unique<RecomposeComplexOpsPass>();
}

0 comments on commit 0103c55

Please sign in to comment.