Iteratively run the main simplification pipeline.
This introduces a new pass, LowerToBackendContract (better name very
welcome), which performs the bulk of the simplifications that we do,
such as:
- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of
transformations, re-running the simplification pipeline (up to a
configurable max-iterations) until the backend contract is satisfied.
This helps break a number of "catch-22" issues where one simplification
depends on another; the latest example is #1131. A sketch of the
iteration strategy is given below.

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
silvasean committed Aug 17, 2022
1 parent 9c8b962 commit 57681f7
Showing 14 changed files with 518 additions and 251 deletions.
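
To make the iteration strategy concrete, here is a minimal sketch (illustrative only, not code from this commit) of the approach: run the simplification pipeline up to max-iterations times, stopping early once the backend contract holds. The satisfiesBackendContract() helper is hypothetical; the real checks live in the new LowerToBackendContract pass.

// Minimal sketch of the fixed-point iteration; illustrative only.
// `satisfiesBackendContract` is a hypothetical stand-in for the real checks
// (value semantics, known ranks/dtypes, no global slots, decomposed ops).
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;

static bool satisfiesBackendContract(ModuleOp module); // hypothetical helper

static LogicalResult
iterateSimplifications(ModuleOp module,
                       const Torch::TorchLoweringPipelineOptions &options) {
  PassManager pm(module.getContext());
  Torch::createTorchSimplificationPipeline(pm, options);
  for (int i = 0; i < options.maxIterations; ++i) {
    if (satisfiesBackendContract(module))
      return success(); // Contract reached; stop iterating.
    // Re-running the whole pipeline lets one simplification unblock another
    // on the next round -- this is what breaks the "catch-22" cases.
    if (failed(pm.run(module)))
      return failure();
  }
  // Contract still not satisfied within the iteration budget.
  return failure();
}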
include/torch-mlir/Dialect/Torch/Transforms/Passes.h (23 additions, 11 deletions)
@@ -28,15 +28,21 @@ createPrepareForGlobalizeObjectGraphPass();

 struct TorchLoweringPipelineOptions
     : public PassPipelineOptions<TorchLoweringPipelineOptions> {
-  // If this option is true, then perform optimizations.
-  // If this option is false, only do the bare minimum for correctness.
-  Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
-                        llvm::cl::init(true)};
-
+  // The maximum number of invocations of the simplification pipeline in
+  // LowerToBackendContract.
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc(
+          "Maximum number of invocations of the simplification pipeline."),
+      llvm::cl::init(10)};
   // If this option is false, decompose complex operations.
   // If this option is true, skip decomposition of complex operations.
-  Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
-                        llvm::cl::init(true)};
+  // TODO: This should be replaced with a list of operations to decompose.
+  // (or some other way to specify the set of allowed ops in the backend
+  // contract)
+  Option<bool> decompose{*this, "decompose-complex-ops",
+                         llvm::cl::desc("Decompose complex operations."),
+                         llvm::cl::init(true)};
 };

 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -50,10 +56,16 @@ void createTorchScriptModuleToTorchBackendPipeline(
 void createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);

-/// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(
+/// Creates a pipeline that simplifies the computations in the program.
+/// This pass does not do any global program restructuring -- it works entirely
+/// within a single semantic model of a `builtin.module` with
+/// `torch.global_slot` ops and `func.func` ops.
+void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);

+/// Creates a pipeline that refines shapes of tensor operations in the program.
+void createTorchShapeRefinementPipeline(OpPassManager &pm);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();

 std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
@@ -78,10 +90,10 @@ createSimplifyShapeCalculationsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();

 std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyConversionToValueSemanticsPass();
+createEraseModuleInitializerPass();

 std::unique_ptr<OperationPass<ModuleOp>>
-createEraseModuleInitializerPass();
+createLowerToBackendContractPass(int maxIterations, bool decompose);

 StringRef getShapeLibrary();

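As a usage note (assumed, not part of the diff): a C++ caller would populate TorchLoweringPipelineOptions and build one of the pipelines declared above roughly as follows. The option names registered via llvm::cl above (max-iterations, decompose-complex-ops) are likewise what the textual pipeline-options syntax accepts.

// Hypothetical caller showing how TorchLoweringPipelineOptions is consumed;
// the pipeline entry point is declared in the header above.
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

using namespace mlir;

void buildTorchLoweringPipeline(MLIRContext &context) {
  torch::Torch::TorchLoweringPipelineOptions options;
  options.maxIterations = 10; // iteration budget for LowerToBackendContract
  options.decompose = true;   // decompose complex ops toward the contract
  PassManager pm(&context);
  torch::Torch::createTorchFunctionToTorchBackendPipeline(pm, options);
  // pm.run(module) would then drive a module to the backend contract.
}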
include/torch-mlir/Dialect/Torch/Transforms/Passes.td (57 additions, 14 deletions)
@@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
   }];
 }

-def VerifyConversionToValueSemantics
-    : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
-  let summary = "Verify that all tensors have been converted to value semantics";
-  let constructor =
-      "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
-  let description = [{
-    Prior passes in the pipeline may have missed converting all tensors to value
-    semantics and we wish to catch such failures early instead of fixing
-    individual cases downstream.
-  }];
-}
-
 def EraseModuleInitializer
     : Pass<"torch-erase-module-initializer", "ModuleOp"> {
   let summary = "Erase the `torch.global_slot.module_initializer` op.";
@@ -273,9 +261,64 @@ def EraseModuleInitializer
   let description = [{
     Backends cannot currently handle module initializers, so we omit them from
     our backend contract. This pass removes the
-    `torch.global_slot.module_initializer` op from the module if legal, or
-    raises an error.
+    `torch.global_slot.module_initializer` op from the module if legal.
   }];
 }

+def LowerToBackendContract
+    : Pass<"torch-lower-to-backend-contract", "ModuleOp"> {
+  let summary = "Perform simplifications until the backend contract is satisfied.";
+  let constructor = [{
+    mlir::torch::Torch::createLowerToBackendContractPass(
+      /*maxIterations=*/10, /*decompose=*/true)
+  }];
+  let description = [{
+    This pass performs the bulk of the lowering of the program's computations
+    to the backend contract. This pass does not do any global program
+    restructuring -- it works entirely within a single semantic model
+    of a `builtin.module` with `torch.global_slot` ops and `func.func` ops.
+
+    This pass runs a set of simplifications within that semantic model until
+    the backend contract is satisfied, and fails if it cannot be satisfied.
+    In particular, the backend contract consists of:
+    - Tensors
+      - Have been converted to value semantics.
+      - Have at least a known rank, though ideally a maximally inferred shape.
+      - Have a known dtype.
+    - `torch.global_slot`'s have been eliminated from the program.
+    - Ops have been decomposed.
+
+    This particular choice of backend contract was born out of a common set of
+    requirements from backends, along with aligning with long-term PyTorch
+    direction of being more tracing-based. The set of simplifications performed
+    here can be thought of as simulating the kinds of simplifications that
+    happen naturally as part of tracing, but in a way that is applicable
+    to our TorchScript frontend. For the LazyTensorCore frontend, the backend
+    contract trivially holds (except for certain decompositions).
+
+    Generally it is not desirable to have a compiler where successful
+    compilation depends on "optimizing hard enough", but in this case, there
+    seems to be enough alignment and recognition in the industry that the
+    Python-based programming model in the source program is too dynamic
+    to feasibly handle in totality without a tracing approach that has access
+    to the source program to re-trace in the face of dynamism (e.g. the ability
+    to do what TorchDynamo calls "graph break"). We are attempting to maintain
+    a practical compiler that works well given the current set of constraints
+    of the TorchScript frontend that PyTorch provides us, and are working to
+    co-design PyTorch's direction so that we land in a place where most of this
+    "optimizing hard enough" is not necessary.
+  }];
+  let options = [
+    Option<"maxIterations", "max-iterations", "int", /*default=*/"10",
+           "Maximum number of invocations of the simplification pipeline.">,
+    // TODO: Make this a configurable set of ops.
+    Option<"decompose", "decompose", "bool", /*default=*/"true",
+           "Decompose ops.">
+
+  ];
+  // TODO: Debug why this is needed, even though the input program has func.func
+  // ops in it.
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
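
To make the tensor bullets of the contract concrete, the check for a single tensor type could look roughly like the sketch below (an assumed helper, not code from this commit). ValueTensorType is torch-mlir's value-semantic tensor type, and its hasSizes()/hasDtype() accessors report whether rank/shape and dtype are known.

// Rough sketch of "this tensor type satisfies the backend contract".
// Not the actual implementation in LowerToBackendContract.cpp.
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

static bool tensorTypeSatisfiesContract(Type type) {
  // Value semantics: must be !torch.vtensor, not the mutable !torch.tensor.
  auto vtensorType = type.dyn_cast<Torch::ValueTensorType>();
  if (!vtensorType)
    return false;
  // At least a known rank (hasSizes) and a known dtype.
  return vtensorType.hasSizes() && vtensorType.hasDtype();
}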
lib/Dialect/Torch/Transforms/CMakeLists.txt (1 addition, 1 deletion)
@@ -6,6 +6,7 @@ add_mlir_library(TorchMLIRTorchPasses
   Passes.cpp
   GlobalizeObjectGraph.cpp
   InlineGlobalSlots.cpp
+  LowerToBackendContract.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
   ReduceOpVariants.cpp
@@ -14,7 +15,6 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ShapeLibrary.cpp
   SimplifyShapeCalculations.cpp
-  VerifyConversionToValueSemantics.cpp

   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms
lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp (7 additions, 10 deletions)
@@ -27,18 +27,15 @@ namespace {
 class EraseModuleInitializerPass
     : public EraseModuleInitializerBase<EraseModuleInitializerPass> {
   void runOnOperation() override {
-    auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
+    for (auto initializer :
+         getOperation().getOps<GlobalSlotModuleInitializerOp>()) {
       auto intialize =
-          cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
-      if (intialize.getNumOperands() != 0) {
-        op.emitError("could not erase non-empty module initializer");
-        return WalkResult::interrupt();
+          cast<InitializeGlobalSlotsOp>(initializer.getBody()->getTerminator());
+      if (intialize.getNumOperands() == 0) {
+        initializer.erase();
       }
-      op.erase();
-      return WalkResult::advance();
-    });
-    if (walkResult.wasInterrupted()) {
-      return signalPassFailure();
+      // The verifier ensures there is only one GlobalSlotModuleInitializerOp.
+      break;
     }
   }
 };
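
Note on the behavior change above: the old pass raised an error on a non-empty module initializer, while the new version erases the initializer only when it is empty and otherwise leaves it in place. Presumably this is what lets the pass run inside the iterative LowerToBackendContract flow: an initializer that is still non-empty on one iteration can become empty (e.g. once global slots are inlined) and be erased on a later one.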
(The remaining 10 changed files are not shown here.)
