diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index c3c271e7badb..bfb4907d4679 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -28,15 +28,21 @@ createPrepareForGlobalizeObjectGraphPass(); struct TorchLoweringPipelineOptions : public PassPipelineOptions { - // If this option is true, then perform optimizations. - // If this option is false, only do the bare minimum for correctness. - Option optimize{*this, "optimize", llvm::cl::desc("Do optimizations."), - llvm::cl::init(true)}; - + // The maximum number of invocations of the simplification pipeline in + // LowerToBackendContract. + Option maxIterations{ + *this, "max-iterations", + llvm::cl::desc( + "Maximum number of invocations of the simplification pipeline."), + llvm::cl::init(10)}; // If this option is false, decompose complex operations. // If this option is true, skip decomposition of complex operations. - Option decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."), - llvm::cl::init(true)}; + // TODO: This should be replaced with a list of operations to decompose. + // (or some other way to specify the set of allowed ops in the backend + // contract) + Option decompose{*this, "decompose-complex-ops", + llvm::cl::desc("Decompose complex operations."), + llvm::cl::init(true)}; }; /// Creates a pipeline that lowers the object graph IR that is produced by @@ -50,10 +56,16 @@ void createTorchScriptModuleToTorchBackendPipeline( void createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); -/// Creates a pipeline that refines shapes of tensor operations in the program. -void createTorchShapeRefinementPipeline( +/// Creates a pipeline that simplifies the computations in the program. +/// This pass does not do any global program restructuring -- it works entirely +/// within a single semantic model of a `builtin.module` with +/// `torch.global_slot` ops and `func.func` ops. +void createTorchSimplificationPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options); +/// Creates a pipeline that refines shapes of tensor operations in the program. +void createTorchShapeRefinementPipeline(OpPassManager &pm); + std::unique_ptr> createAdjustCallingConventionsPass(); std::unique_ptr> createRefineTypesPass(); @@ -78,10 +90,10 @@ createSimplifyShapeCalculationsPass(); std::unique_ptr> createDropShapeCalculationsPass(); std::unique_ptr> -createVerifyConversionToValueSemanticsPass(); +createEraseModuleInitializerPass(); std::unique_ptr> -createEraseModuleInitializerPass(); +createLowerToBackendContractPass(int maxIterations, bool decompose); StringRef getShapeLibrary(); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 6ff61b088c31..93f70d9e5419 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp" }]; } -def VerifyConversionToValueSemantics - : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> { - let summary = "Verify that all tensors have been converted to value semantics"; - let constructor = - "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()"; - let description = [{ - Prior passes in the pipeline may have missed converting all tensors to value - semantics and we wish to catch such failures early instead of fixing - individual cases downstream. - }]; -} - def EraseModuleInitializer : Pass<"torch-erase-module-initializer", "ModuleOp"> { let summary = "Erase the `torch.global_slot.module_initializer` op."; @@ -273,9 +261,64 @@ def EraseModuleInitializer let description = [{ Backends cannot currently handle module initializers, so we omit them from our backend contract. This pass removes the - `torch.global_slot.module_initializer` op from the module if legal, or - raises an error. + `torch.global_slot.module_initializer` op from the module if legal. + }]; +} + +def LowerToBackendContract + : Pass<"torch-lower-to-backend-contract", "ModuleOp"> { + let summary = "Perform simplifications until the backend contract is satisfied."; + let constructor = [{ + mlir::torch::Torch::createLowerToBackendContractPass( + /*maxIterations=*/10, /*decompose=*/true) + }]; + let description = [{ + This pass performs the bulk of the lowering of the program's computations + to the backend contract. This pass does not do any global program + restructuring -- it works entirely within a single semantic model + of a `builtin.module` with `torch.global_slot` ops and `func.func` ops. + + This pass runs a set of simplifications within that semantic model until + the backend contract is satisfied, and fails if it cannot be satisfied. + In particular, the backend contract consists of: + - Tensors + - Have been converted to value semantics. + - Have at least a known rank, though ideally a maximally inferred shape. + - Have a known dtype. + - `torch.global_slot`'s have been eliminated from the program. + - Ops have been decomposed. + + This particular choice of backend contract was born out of a common set of + requirements from backends, along with aligning with long-term PyTorch + direction of being more tracing-based. The set of simplifications performed + here can be thought of as simulating the kinds of simplifications that + happen naturally as part of tracing, but in a way that is applicable + to our TorchScript frontend. For the LazyTensorCore frontend, the backend + contract trivially holds (except for certain decompositions). + + Generally it is not desirable to have a compiler where successful + compilation depends on "optimizing hard enough", but in this case, there + seems to be enough alignment and recognition in the industry that the + Python-based programming model in the source program is too dynamic + to feasibly handle in totality without a tracing approach that has access + to the source program to re-trace in the face of dynamism (e.g. the ability + to do what TorchDynamo calls "graph break"). We are attempting to maintain + a practical compiler that works well given the current set of constraints + of the TorchScript frontend that PyTorch provides us, and are working to + co-design PyTorch's direction so that we land in a place where most of this + "optimizing hard enough" is not necessary. }]; + let options = [ + Option<"maxIterations", "max-iterations", "int", /*default=*/"10", + "Maximum number of invocations of the simplification pipeline.">, + // TODO: Make this a configurable set of ops. + Option<"decompose", "decompose", "bool", /*default=*/"true", + "Decompose ops."> + + ]; + // TODO: Debug why this is needed, even though the input program has func.func + // ops in it. + let dependentDialects = ["func::FuncDialect"]; } #endif // TORCHMLIR_TORCH_PASSES diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index 741901646b35..d9de61f0063d 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_library(TorchMLIRTorchPasses Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp + LowerToBackendContract.cpp MaximizeValueSemantics.cpp PrepareForGlobalizeObjectGraph.cpp ReduceOpVariants.cpp @@ -14,7 +15,6 @@ add_mlir_library(TorchMLIRTorchPasses ReifyShapeCalculations.cpp ShapeLibrary.cpp SimplifyShapeCalculations.cpp - VerifyConversionToValueSemantics.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms diff --git a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp index f51657a45e6a..450d84b22ed3 100644 --- a/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp +++ b/lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp @@ -27,18 +27,15 @@ namespace { class EraseModuleInitializerPass : public EraseModuleInitializerBase { void runOnOperation() override { - auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) { + for (auto initializer : + getOperation().getOps()) { auto intialize = - cast(op.getBody()->getTerminator()); - if (intialize.getNumOperands() != 0) { - op.emitError("could not erase non-empty module initializer"); - return WalkResult::interrupt(); + cast(initializer.getBody()->getTerminator()); + if (intialize.getNumOperands() == 0) { + initializer.erase(); } - op.erase(); - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) { - return signalPassFailure(); + // The verifier ensures there is only one GlobalSlotModuleInitializerOp. + break; } } }; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp new file mode 100644 index 000000000000..14b40278bde4 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -0,0 +1,247 @@ +//===- LowerToBackendContract.cpp --------------------------------*- C++-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "torch-lower-to-backend-contract" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +//===----------------------------------------------------------------------===// +// Checking the backend contract. +//===----------------------------------------------------------------------===// + +static LogicalResult checkType(Operation *op, Type type, + bool actuallyEmitDiagnostics) { + // Allow various scalar types that backends are expected to be able to handle. + if (type.isa()) + return success(); + + // Backends are not expected to support dynamic computations on these types, + // but they frequently appear as parameters to ops which backends + // can statically pattern match and eliminate from the program. + // For example, a tensor operand might be optional, and the backend + // will pattern-match statically whether it is passed as a tensor or None. + if (type.isa()) + return success(); + + // We blanket prohibit non-value-semantic tensors. + // All of our backends are currently based on value-semantic tensors, so + // we consider it our responsibility to lower all non-value-semantic tensors + // to value-semantic tensors. + if (type.isa()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError("unsupported by backend contract: non-value tensor type") + .attachNote() + .append("this is likely due to a missing case in the " + "MaximizeValueSemantics pass"); + } else { + return failure(); + } + } + + // For value-semantic tensors, we require at least a known rank and dtype. + // We are not aware of a situation where our backends can handle an unranked + // tensor type or a tensor with a dynamic dtype. + // + // There are somewhat fundamental reasons for this. In particular, the problem + // of unranked codegen is completely different from the problem of ranked + // codegen (since ranked corresponds to a fixed loop nest structure). For all + // codegen systems we are aware of, the program must be reduced to operate + // on ranked tensors at some point in compilation, and we are not aware of + // any backend with a general solution to this problem before it reaches + // codegen. So we consider it our responsibility to eliminate unranked tensor + // from the program. + // + // We aren't aware of any backend with any infrastructure to represent dynamic + // dtypes, let alone transform and optimize them. Additionally, it is unlikely + // that any backend, even if it supports dynamic dtypes in some form, will + // have an sufficiently rich system for representing PyTorch type promotion + // rules. So we consider it our responsibility to ensure that all dtypes are + // statically known. + if (auto tensorType = type.dyn_cast()) { + if (!tensorType.hasSizes()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError( + "unsupported by backend contract: tensor with unknown rank") + .attachNote() + .append("this is likely due to a missing shape transfer function " + "in shape_lib_gen.py"); + } else { + return failure(); + } + } + if (!tensorType.hasDtype()) { + if (actuallyEmitDiagnostics) { + return op + ->emitError( + "unsupported by backend contract: tensor with unknown dtype") + .attachNote() + .append("this is likely due to a missing case in RefineTypes"); + } else { + return failure(); + } + } + return success(); + } + + // Optional types are also in the category of types which we don't expect + // backends to dynamically compute with, but they can be pattern matched + // in many cases that are practically necessary. + if (auto optionalType = type.dyn_cast()) { + // TODO: Be stricter about tensor types. + // See comment below for ListType. + if (optionalType.getContainedType().isa()) + return success(); + return checkType(op, optionalType.getContainedType(), + actuallyEmitDiagnostics); + } + // List types are also in the category of types which we don't expect + // backends to dynamically compute with, but they can be pattern matched + // in many cases that are practically necessary. For example, the + // strides of a convolution op are represented as a list. + if (auto listType = type.dyn_cast()) { + // TODO: Be stricter about tensor types. + // For the moment, there are cases (such as for torch.cat) where we end + // up with `!torch.list` which doesn't have shape or dtype in + // the contained type information. Somehow this slips through and works. + // We should be stricter about this and properly infer the contained type + // and shape. + if (listType.getContainedType().isa()) + return success(); + return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics); + } + // Tuple types are also in the category of types which we don't expect + // backends to dynamically compute with, but they can be pattern matched + // in many cases that are practically necessary. + if (auto tupleType = type.dyn_cast()) { + for (auto containedType : tupleType.getContainedTypes()) { + if (failed(checkType(op, containedType, actuallyEmitDiagnostics))) + return failure(); + } + return success(); + } + + // Unsupported type. + if (actuallyEmitDiagnostics) { + return op->emitError("unsupported by backend contract: type ") << type; + } else { + return failure(); + } +} + +static bool satisfiesBackendContract(ModuleOp module, + bool actuallyEmitDiagnostics = false) { + // We do not permit `torch.global_slot`'s in the backend contract, since + // support for them is not widespread, and this does not align with PyTorch's + // more tracing-based direction. + // + // We just check for the GlobalSlotModuleInitializerOp since its verifier + // ensures that the set of global slots matches those initialized by the + // module initializer. + auto walkResult0 = module.walk([&](Torch::GlobalSlotModuleInitializerOp op) { + if (actuallyEmitDiagnostics) { + // Report the error on the terminator to avoid dumping the whole + // initializer itself, which can have pages of ops in it. + op.getBody() + ->getTerminator() + ->emitError("unsupported by backend contract: module initializers") + .attachNote() + .append("this is likely due to InlineGlobalSlots being unable to " + "inline a global slot"); + } + return WalkResult::interrupt(); + }); + if (walkResult0.wasInterrupted()) + return false; + + // Check all the type of all Value's in the program. + // + // A pre-order walk gives a more intuitive "first error". + // TODO: Should we report more than the first error? + // How do we avoid making it too spammy? + auto walkResult1 = module.walk([&](Block *block) { + for (BlockArgument arg : block->getArguments()) + if (failed(checkType(block->getParentOp(), arg.getType(), + actuallyEmitDiagnostics))) { + return WalkResult::interrupt(); + } + for (Operation &op : *block) + for (OpResult result : op.getResults()) + if (failed(checkType(&op, result.getType(), actuallyEmitDiagnostics))) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + if (walkResult1.wasInterrupted()) + return false; + return true; +} + +namespace { +class LowerToBackendContractPass + : public LowerToBackendContractBase { +public: + LowerToBackendContractPass() = default; + LowerToBackendContractPass(int maxIterations, bool decompose) { + this->maxIterations = maxIterations; + this->decompose = decompose; + } + void runOnOperation() override { + ModuleOp module = getOperation(); + + OpPassManager pm(module.getOperationName()); + TorchLoweringPipelineOptions options; + options.decompose = decompose; + createTorchSimplificationPipeline(pm, options); + + int i = 0; + do { + if (i++ == maxIterations) { + LLVM_DEBUG({ + llvm::dbgs() << "LowerToBackendContractPass: " + << "failed to satisfy backend contract after " + << maxIterations + << " iterations of the simplification pipeline\n"; + }); + // Show the diagnostics. + (void)satisfiesBackendContract(module, + /*actuallyEmitDiagnostics=*/true); + return signalPassFailure(); + } + + if (failed(runPipeline(pm, module))) + return signalPassFailure(); + } while (!satisfiesBackendContract(module)); + LLVM_DEBUG({ + llvm::dbgs() << "LowerToBackendContractPass: " + << "succeeded after " << i + << " iterations of the simplification pipeline\n"; + }); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createLowerToBackendContractPass(int maxIterations, + bool decompose) { + return std::make_unique(maxIterations, decompose); +} diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp index d89477d6d1f1..f3f6641ebd33 100644 --- a/lib/Dialect/Torch/Transforms/Passes.cpp +++ b/lib/Dialect/Torch/Transforms/Passes.cpp @@ -31,6 +31,10 @@ void mlir::torch::registerTorchPasses() { "Pipeline lowering a Torch function to Torch backend form.", mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline); mlir::PassPipelineRegistration( + "torch-simplification-pipeline", + "Pipeline simplifying computations in the program.", + mlir::torch::Torch::createTorchSimplificationPipeline); + mlir::PassPipelineRegistration<>( "torch-shape-refinement-pipeline", "Pipeline refining shapes of tensors.", mlir::torch::Torch::createTorchShapeRefinementPipeline); } @@ -66,131 +70,82 @@ void mlir::torch::Torch::createTorchScriptModuleToTorchBackendPipeline( void mlir::torch::Torch::createTorchFunctionToTorchBackendPipeline( OpPassManager &pm, const TorchLoweringPipelineOptions &options) { - // General considerations: As a matter of bring-up, we are simultaneously - // building out the frontend pipeline and also co-developing the backend - // support story as well. This means that sometimes the most expedient way to - // support a given program is to "optimize hard enough" that the parts of the - // program that touch unimplemented backend support go away (constant folded, - // dead-code-eliminated, etc.). In the fullness of time, most of that - // optimization should not be necessary, and we should have an "O0" pipeline - // that runs practically no optimizations. - // However, as a matter of expediency, at the moment we do run those - // optimizations. We guard those passes under the `options.optimize` option - // (which default to true, currently). We leave notes with the `OPT-ONLY` tag - // why we currently need that pass for correctness. - // We should eventually remove those passes from the default pipeline once - // backends have enough support. - // In particular the following features are needed in some form from backends: - // - Error handling (RaiseException + error string formatting) - // - First-class list type - // - torch.global_slot lowering - // - ... - // Please try to keep this list somewhat up to date when adding - // "optimize hard enough that it works" transformations. - // Incorporate user annotations and remove signature Python-isms. pm.addPass(createAdjustCallingConventionsPass()); + // Perform the bulk of lowering to the backend contract. + // See the pass documentation for more information. + pm.addPass(createLowerToBackendContractPass(options.maxIterations, + options.decompose)); +} - // TODO: Remove options.optimize and this OPT-ONLY stuff -- we are already way - // past the point of no return for it being necessary for functional - // correctness. - if (options.optimize) { - // Eliminate the PrimTupleIndexOp generated from the - // adjustCallingConventions - pm.addNestedPass(createCanonicalizerPass()); - // Inline global slots, which for most inference scenarios deletes them. - // This also exposes more information to intraprocedural transformations - // below like MaximizeValueSemantics and RefineTypes. - // OPT-ONLY: Don't rely on this pass to "lower" global slots by deleting. - // Also don't rely on this pass to expose constants into the program to - // simplify handling of "optional". - pm.addPass(createInlineGlobalSlotsPass()); - // After doing a first round of inlining global slots, canonicalize again to - // take advantage of optimization opportunities exposed by the inlined - // global slots. In particular, this is functionally necessary now because - // large amounts of control flow are guarded by an "is training" flag, so - // inlining removes certain mutating operations done on the slots enabling - // them to be deleted. - // TODO: In full generality, we need to do a fixed-point iteration of - // shape inference, maximizing value semantics, decomposition, inling global - // slots, and canonicalization. - pm.addNestedPass(createCanonicalizerPass()); - // Inline again, cleaning up any remaining global slots that might be dead - // now. - pm.addPass(createInlineGlobalSlotsPass()); - // Erase the module initializers (or fail compilation), since they aren't - // permitted in our backend contract at the moment. - pm.addPass(Torch::createEraseModuleInitializerPass()); - } - +// A simplification pipeline to establish the invariants of the backend +// contract (see `satisfiedBackendContract` in `LowerToBackendContract`). +// +// We structure this so that a single run of this pipeline is enough for +// most models, but it is possible for it to take multiple runs to fully +// clean things up when there are cyclic dependencies between certain +// simplifications, such as a decomposition relying on shape refinement which +// depends on another decomposition. +// +// Although technically this pipeline is an implementation detail of +// LowerToBackendContract, we expose it here to help debugging. +// +// LowerToBackendContract will run this pipeline as many times as necessary, but +// in general, it is costly to re-run this pipeline, since all the passes do +// O(module size) work. We want the number of iterations of this pipeline +// to be bounded by meaningful "always in practice small" program properties, +// such as loop nesting depth, number of sequentially dependent steps of +// constant global slots proving that other global slots are dead, etc. +// +// It is generally always possible to construct a pathological input that will +// exceed the number of iterations. If we do find practical cases with +// O(module size) number of iterations of this simplification pipeline, then +// we may need to adjust the approach, such as to do some of the transformations +// together at finer granularity. +void mlir::torch::Torch::createTorchSimplificationPipeline( + OpPassManager &pm, const TorchLoweringPipelineOptions &options) { + // General cleanup. + pm.addNestedPass(createCanonicalizerPass()); + // Inline global slots to expose a bunch of simplification opportunities + // from constant hyperparameters, weights, etc. + pm.addPass(createInlineGlobalSlotsPass()); + // Erase the module initializer if we have proven that all the global slots + // are gone. + pm.addPass(createEraseModuleInitializerPass()); + // Clean up again to avoid needing to to back around the fixed-point + // iteration. + pm.addNestedPass(createCanonicalizerPass()); // Reduce variants of ops to a smaller set of primitives. pm.addNestedPass(createReduceOpVariantsPass()); - - if (options.optimize) { - // OPT-ONLY: Right now we rely on this to eliminate certain branches that - // guard unreachable code that backends can't handle yet, such as lists, - // RaiseException, unimplemented tensor ops, and only-used-in-training - // operations on `torch.global_slot`'s. - pm.addNestedPass(createCanonicalizerPass()); - // OPT-ONLY: We may have deleted some `torch.global_slot.get` / - // `torch.global_slot.get` ops, which may have left more - // `torch.global_slot`'s unused. - pm.addPass(createSymbolDCEPass()); - } - - //===--------------------------------------------------------------------===// - // Lowering to ranked !torch.vtensors of known dtype. - //===--------------------------------------------------------------------===// - + pm.addNestedPass(createCanonicalizerPass()); + // Remove dead global slots. + pm.addPass(createSymbolDCEPass()); // Convert the bulk of non-ABI-visible !torch.tensor's to !torch.vtensor's. pm.addNestedPass(Torch::createMaximizeValueSemanticsPass()); - - // Update the return op to return value tensors and remove dead ops. + // Update the return op to return value tensors. pm.addPass(Torch::createRefinePublicReturnPass()); pm.addNestedPass(createCanonicalizerPass()); - - // Ensure that all tensors have been converted to value semantics. - pm.addPass(Torch::createVerifyConversionToValueSemanticsPass()); - // Do shape refinement. - // This must be run before RefineTypes (which primarily does dtype inference), - // because Torch type promotion rules actually depend on the shape of the - // operand. - createTorchShapeRefinementPipeline(pm, options); + // This should be run before RefineTypes (which primarily does dtype + // inference), because Torch type promotion rules actually depend on the shape + // of the operand. + createTorchShapeRefinementPipeline(pm); // Refine types in the program, which mainly means inferring dtypes of ops. pm.addNestedPass(Torch::createRefineTypesPass()); - // Propagate to ABI return types the shape/dtype information discovered by // the previous pass. Doing this is ABI-compatible for our backends. pm.addPass(Torch::createRefinePublicReturnPass()); - - if (options.optimize) { - // This can fold away some branches given the information got from - // RefineTypes before doing maximize value sematics which only works with - // basic blocks. - pm.addNestedPass(createCanonicalizerPass()); - } - - if (options.optimize) { - // All the type refinement we've done above has exposed new information - // that allows folding away more stuff. - // OPT-ONLY: Right now we rely on this to eliminate certain - // branches that guard unreachable code that backends can't handle yet, such - // as lists, RaiseException, unimplemented aten ops, and - // only-used-in-training operations on `torch.global_slot`'s. - pm.addNestedPass(createCanonicalizerPass()); - } - + // This can fold away some branches given the information got from + // RefineTypes before doing maximize value sematics which only works with + // basic blocks. + pm.addNestedPass(createCanonicalizerPass()); if (options.decompose) { pm.addNestedPass(Torch::createDecomposeComplexOpsPass()); pm.addNestedPass(createCanonicalizerPass()); } - - // TODO: VerifyTorchBackendContractPass. } -void mlir::torch::Torch::createTorchShapeRefinementPipeline( - OpPassManager &pm, const TorchLoweringPipelineOptions &options) { +void mlir::torch::Torch::createTorchShapeRefinementPipeline(OpPassManager &pm) { // Reify the shape functions for each op that is present in the shape library. pm.addPass(Torch::createReifyShapeCalculationsPass()); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index c5e865790d31..80ae4661431b 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -1152,7 +1152,7 @@ void TypeAnalysis::visitAtenEmbeddingBagOp(Operation *op) { resultIntKnowledge.dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed); - for (int64_t i = 1; i < 4; i++) { + for (int64_t i = 1, e = op->getNumResults(); i < e; i++) { incorporateKnowledge(op->getResult(i), resultIntKnowledge); } return; @@ -1259,6 +1259,12 @@ void TypeAnalysis::visitAtenTensorOp(AtenTensorOp op) { while (auto listType = type.dyn_cast()) { type = listType.getContainedType(); } + // TODO: Support tensor as the contained type of the list. + // These are the only types handled by fillInDTypeGivenDTypeAndDataType below. + if (!type.isa()) { + incorporateKnowledge(op.getResult(), knowledge); + return; + } fillInDTypeGivenDTypeAndDataType(knowledge, dtype, type); incorporateKnowledge(op.getResult(), knowledge); } @@ -1418,13 +1424,13 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { }; if (auto tensorType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); return getRefinedTensorType(tensorType, knowledge); } else if (auto optionalType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.optional == OptionalKnowledge::isNone) @@ -1438,7 +1444,7 @@ static Type getMostRefinedStaticType(Value v, DataFlowSolver &solver) { } } else if (auto scalarType = v.getType().dyn_cast()) { const ValueState *latticeElement = solver.lookupState(v); - if (!latticeElement) + if (!latticeElement || latticeElement->isUninitialized()) return nullptr; const ValueKnowledge &knowledge = latticeElement->getValue(); if (knowledge.kind == torch_upstream::TypeKind::IntType) diff --git a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp b/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp deleted file mode 100644 index 7a055ece3401..000000000000 --- a/lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp +++ /dev/null @@ -1,56 +0,0 @@ -//===- VerifyConversionToValueSemantics.cpp ----------------------*- C++-*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" - -#include "mlir/IR/BuiltinOps.h" -#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" -#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" - -using namespace mlir; -using namespace mlir::torch::Torch; - -static LogicalResult checkValueType(Operation *op, Value value) { - auto isNotValueTensorType = value.getType().isa(); - return isNotValueTensorType - ? op->emitError( - "found a non-value tensor type, this is likely due to a " - "missing case in the MaximizeValueSemantics pass") - : success(); -} - -namespace { -class VerifyConversionToValueSemanticsPass - : public VerifyConversionToValueSemanticsBase< - VerifyConversionToValueSemanticsPass> { - void runOnOperation() override { - auto walkResult = getOperation().walk([&](Block *block) { - for (BlockArgument arg : block->getArguments()) - if (failed(checkValueType(block->getParentOp(), arg))) - return WalkResult::interrupt(); - - for (Operation &op : *block) - for (OpResult result : op.getResults()) - if (failed(checkValueType(&op, result))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }); - - if (walkResult.wasInterrupted()) - signalPassFailure(); - } -}; - -} // namespace - -std::unique_ptr> -mlir::torch::Torch::createVerifyConversionToValueSemanticsPass() { - return std::make_unique(); -} diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index c74cc742aaca..2aec9567d0ee 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -77,16 +77,14 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToArithPass()); pm.addNestedPass(memref::createExpandOpsPass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // Resolve `dim` ops on tensors (which currently live in the `memref` - // dialect for some reason -- we don't have memrefs at this level). - pm.addNestedPass( - memref::createResolveShapedTypeResultDimsPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // Resolve `dim` ops on tensors (which currently live in the `memref` + // dialect for some reason -- we don't have memrefs at this level). + pm.addNestedPass( + memref::createResolveShapedTypeResultDimsPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Finish the type conversion from `torch` types to the types of the // linalg-on-tensors backend contract. @@ -111,12 +109,10 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Finish the type conversion from `torch` types to the types of the // TOSA backend contract. @@ -140,21 +136,17 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( pm.addNestedPass(createConvertTorchToMhloPass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Convert CHLO ops to MHLO ops pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - if (options.optimize) { - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - } + // Clean up any non-canonical code introduced above.. + pm.addNestedPass(createCanonicalizerPass()); + // The resolution of `dim` ops tends to create identical ops. CSE them. + pm.addNestedPass(createCSEPass()); // Finish the type conversion from `torch` types to the types of the // MHLO backend contract. @@ -162,4 +154,4 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); } -#endif \ No newline at end of file +#endif diff --git a/test/Dialect/Torch/erase-module-initializer.mlir b/test/Dialect/Torch/erase-module-initializer.mlir index 232de3e6f671..c0bbbdbbddb0 100644 --- a/test/Dialect/Torch/erase-module-initializer.mlir +++ b/test/Dialect/Torch/erase-module-initializer.mlir @@ -6,15 +6,3 @@ torch.global_slot.module_initializer { torch.initialize.global_slots [ ] } - -// ----- - -torch.global_slot @slot0 : !torch.int - -// expected-error@+1 {{could not erase non-empty module initializer}} -torch.global_slot.module_initializer { - %0 = torch.constant.int 0 - torch.initialize.global_slots [ - @slot0(%0: !torch.int) - ] -} diff --git a/test/Dialect/Torch/lower-to-backend-contract-error.mlir b/test/Dialect/Torch/lower-to-backend-contract-error.mlir new file mode 100644 index 000000000000..824f3ae23467 --- /dev/null +++ b/test/Dialect/Torch/lower-to-backend-contract-error.mlir @@ -0,0 +1,61 @@ +// RUN: torch-mlir-opt -torch-lower-to-backend-contract -split-input-file -verify-diagnostics %s + +torch.global_slot.module_initializer { + %0 = torch.constant.int 1 + // expected-error @+2 {{unsupported by backend contract: module initializers}} + // expected-note @+1 {{this is likely due to}} + torch.initialize.global_slots [ + @slot0(%0 : !torch.int) + ] +} +torch.global_slot @slot0 : !torch.int + + +// ----- + +// expected-error @+2 {{unsupported by backend contract: non-value tensor type}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.tensor) { + return +} + +// ----- + +// expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.vtensor<*,f32>) { + return +} + +// ----- + +// expected-error @+2 {{unsupported by backend contract: tensor with unknown dtype}} +// expected-note @+1 {{this is likely due to}} +func.func @f(%arg0: !torch.vtensor<[],unk>) { + return +} + +// ----- + +// expected-error @+1 {{unsupported by backend contract: type '!torch.any'}} +func.func @f(%arg0: !torch.any) { + return +} + +// ----- + +// Test case: checking of op results. +// TODO: In theory we could diagnose every single value, but for now we bail out on the first one. + +func.func @f(%arg0: !torch.bool, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[7],f32>) -> !torch.vtensor<*,f32> { + // expected-error @+2 {{unsupported by backend contract: tensor with unknown rank}} + // expected-note @+1 {{this is likely due to}} + %0 = torch.prim.If %arg0 -> (!torch.vtensor<*,f32>) { + %1 = torch.tensor_static_info_cast %arg1 : !torch.vtensor<[],f32> to !torch.vtensor<*,f32> + torch.prim.If.yield %1 : !torch.vtensor<*,f32> + } else { + %2 = torch.tensor_static_info_cast %arg2 : !torch.vtensor<[7],f32> to !torch.vtensor<*,f32> + torch.prim.If.yield %2 : !torch.vtensor<*,f32> + } + return %0 : !torch.vtensor<*,f32> +} diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index 38ce507fe785..b8ee124fb0e3 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -240,3 +240,35 @@ func.func @prim.dtype(%arg: !torch.vtensor<*,bf16>) -> !torch.vtensor<*,unk> { return %result2 : !torch.vtensor<*,unk> } + +// ----- + +// Check that we don't crash on this input. + +// CHECK-LABEL: func.func @forward +func.func @forward() -> !torch.vtensor { + %false = torch.constant.bool false + %none = torch.constant.none + %0 = torch.prim.ListConstruct : () -> !torch.list + // CHECK: torch.aten.tensor + %1 = torch.aten.tensor %0, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor + return %1 : !torch.vtensor +} + +// ----- + +// Check that we don't crash on this input. +// TODO: This appears to result in aten.mul.Tensor not being visited. +// We should investigate why that happens. + +// CHECK-LABEL: func.func @forward +func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) { + %0 = torch.prim.If %arg0 -> (!torch.tensor) { + torch.prim.If.yield %arg1 : !torch.tensor + } else { + torch.prim.If.yield %arg1 : !torch.tensor + } + %1 = torch.copy.to_vtensor %0 : !torch.vtensor + %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor + return +} diff --git a/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir b/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir deleted file mode 100644 index 7ae6b1e19070..000000000000 --- a/test/Dialect/Torch/verify-conversion-to-value-semantics.mlir +++ /dev/null @@ -1,9 +0,0 @@ -// RUN: torch-mlir-opt -split-input-file -verify-diagnostics %s -torch-verify-conversion-to-value-semantics - -// ----- - -func.func @result_is_non_value_tensor(%arg: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> { - // @expected-error@+1 {{found a non-value tensor type, this is likely due to a missing case in the MaximizeValueSemantics pass}} - %neg = torch.aten.neg %arg : !torch.vtensor<[2],f32> -> !torch.tensor - return %arg : !torch.vtensor<[2],f32> -} diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 984c57487c9f..c003e4fa1c4c 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -182,7 +182,6 @@ cc_library( "lib/Dialect/Torch/Transforms/ShapeLibrary.cpp", "lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp", "lib/Dialect/Torch/Transforms/PassDetail.h", - "lib/Dialect/Torch/Transforms/VerifyConversionToValueSemantics.cpp", ], hdrs = [ "include/torch-mlir/Dialect/Torch/Transforms/Passes.h",