diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 28138edcb4787..5ec4b71c111e5 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,6 +105,61 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchBackendLegalToTosaCustom + : Pass<"convert-torch-backend-legal-to-tosa-custom", "func::FuncOp"> { + let summary = "Convert Torch backend-legal operations to TOSA CustomOps."; + let description = [{ + This pass extends the selective decomposition patterns to TOSA backend, + such that it will convert operations marked to custom-ops by the user + for the TOSA backend to CustomOps in TOSA dialect. + The backend-custom operations will not be decomposed or lowered through + the existing Torch to TOSA conversion patterns, even if such patterns + exist in the subsequent Torch to TOSA conversion pass; instead a TOSA + CustomOp will be created with the specifications below: + 1. The 'identifier' attribute in the CustomOp will be the operation + name of the corresponding ATen operation. + 2. The 'config' attribute in the CustomOp will be the specification + of the compilation pipeline, i.e. "torch_mlir" + 3. All inputs to the ATen operation will be converted into legal TOSA + operations. There will be no distiction between inputs and constant + attributes; all will be treated as inputs to the CustomOp. It is up + to the consuming backend to deduce them based on the ATen operation + semantics. + 4. Since TOSA conversion pattern of Torch operations are responsible + to consume and convet the constant attributes (such as 'axis' for a + reduction operation), the TOSA CustomOp conversion will also match + and rewrite these attributes as TOSA ConstOps as well. Specifically: + - !torch.bool -> tosa.ConstOp (of tensor type i64) + - !torch.int -> tosa.ConstOp (of tensor type i64) + - !torch.float -> tosa.ConstOp (of tensor type f32) + - !torch.list -> tosa.ConstOp (of tensor type i64) + - !torch.none -> tosa.CustomOp (of tensor type i1) + The 'identifier' attribute of this + CustomOp is 'torch.constant.none' + - !torch.str -> tosa.CustomOp (of tensor type i1) + The 'identifier' attribute of this + CustomOp is 'torch.constant.str' + All other Torch ATen operations will be lowered to TOSA by the Torch + to TOSA conversion pass after this one. + TODO: Extend the contract for other Torch constant operations such as: + - torch.constant.device + - torch.constant.number + 5. The input operands of the backend-legal Torch operation and the TOSA + CustomOp will have the same order. + TODO: Update this pass to populate the 'config' attribute of the TOSA + CustomOp to establish a proper operand mapping scheme between + the backend-legal Torch operation and the TOSA CustomOp. + (See the commit at https://github.com/llvm/llvm-project/commit/d94ee70f4f01e4d9eec49e02eff57a5655618401) + }]; + let options = [ListOption< + "customOps", "custom-ops", "std::string", + "List of operations considered backend-legal that should be converted to" + " CustomOps in TOSA dialect.", + "llvm::cl::ZeroOrMore">]; + let constructor = + "mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(/*customOps=*/{})"; +} + def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db10..461979dc896ae 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -16,8 +16,11 @@ namespace mlir { namespace torch { +std::unique_ptr> +createConvertTorchBackendLegalToTosaCustomPass(ArrayRef customOps); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index fd350da1d61ec..2729db73335e4 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -28,7 +28,16 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); +struct TosaBackendPipelineOptions + : public PassPipelineOptions { + ListOption customOps{ + *this, "custom-ops", + llvm::cl::desc("List of operations considered backend-legal and " + "should be converted to CustomOps in TOSA dialect.")}; +}; + +void createTorchBackendToTosaBackendPipeline( + OpPassManager &pm, const TosaBackendPipelineOptions &options); // Do not register the torch-to-mhlo pipeline if mhlo target is disabled #ifdef TORCH_MLIR_ENABLE_MHLO diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt index e1f5142bd9242..d83322b9cdbf1 100644 --- a/lib/Conversion/TorchToTosa/CMakeLists.txt +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa + TorchBackendLegalToTosaCustom.cpp TorchToTosa.cpp TosaLegalizeUtils.cpp TosaLegalizeCommon.cpp diff --git a/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp b/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp new file mode 100644 index 0000000000000..6286b6a5ada4e --- /dev/null +++ b/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp @@ -0,0 +1,224 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertBackendLegalAtenOpToCustomOp : public ConversionPattern { +public: + SetVector customOps; + + ConvertBackendLegalAtenOpToCustomOp(TypeConverter &typeConverter, + MLIRContext *context, + ArrayRef customOps) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) { + this->customOps = SetVector(customOps.begin(), customOps.end()); + } + + Value convertOperandToTensor(Value inputOperand, PatternRewriter &rewriter, + Operation *backendLegalOp) const { + // Get the Torch Op to find the constant attributes attached to the + // backendLegalOp + Value torchInputOperand = inputOperand; + if (auto unrealizedCastOp = dyn_cast_or_null( + inputOperand.getDefiningOp())) { + torchInputOperand = unrealizedCastOp.getInputs()[0]; + } + // Handle the special case where input operand is an argument to the module + // function + if (!torchInputOperand.getDefiningOp()) + return inputOperand; + + return TypeSwitch(torchInputOperand.getDefiningOp()) + .Case([&](Operation *boolOperand) -> Value { + bool boolConstAttr; + if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + boolConstAttr, {}) + .value(); + } + return nullptr; + }) + // TODO Add support for converting "torch.constant.device" + .Case( + [&](Operation *strOperand) -> Value { return nullptr; }) + .Case([&](Operation *intOperand) -> Value { + int64_t intConstAttr; + if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + intConstAttr, {}) + .value(); + } + return nullptr; + }) + .Case([&](Operation *floatOperand) -> Value { + double floatConstAttr; + if (matchPattern(floatOperand, + m_TorchConstantFloat(&floatConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + floatConstAttr, {}) + .value(); + } + return nullptr; + }) + .Case([&](Operation *noneOperand) -> Value { + auto noneCustomOp = rewriter.create( + backendLegalOp->getLoc(), + RankedTensorType::get({}, rewriter.getIntegerType(1)), + rewriter.getStringAttr("torch.none"), ValueRange{}); + return noneCustomOp.getResult(0); + }) + // TODO Add support for converting "torch.constant.number" + .Case( + [&](Operation *strOperand) -> Value { return nullptr; }) + .Case([&](Operation *strOperand) -> Value { + std::string strConstAttr; + if (matchPattern(strOperand, m_TorchConstantStr(strConstAttr))) { + auto strCustomOp = rewriter.create( + backendLegalOp->getLoc(), + RankedTensorType::get({}, rewriter.getIntegerType(8)), + rewriter.getStringAttr("torch.str"), ValueRange{}); + strCustomOp.getOperation()->setAttr( + rewriter.getStringAttr("str"), + rewriter.getStringAttr(strConstAttr)); + return strCustomOp.getResult(0); + } + return nullptr; + }) + .Case( + [&](Operation *intListConstructOperand) -> Value { + SmallVector intConstListAttr; + if (matchPattern(intListConstructOperand, + m_TorchConstantIntList(intConstListAttr))) { + return tosa::getConstTensor( + rewriter, backendLegalOp, intConstListAttr, + {static_cast(intConstListAttr.size())}) + .value(); + } + return nullptr; + }) + .Default([&](Operation *defaultOperand) { return inputOperand; }); + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + if (customOps.contains(op->getName().getStringRef())) { + SmallVector customOpInputOperands; + + for (auto operand : operands) { + Value customOpInputOperand = + convertOperandToTensor(operand, rewriter, op); + if (!customOpInputOperand) { + return rewriter.notifyMatchFailure( + op, + "failed to match the constant operand of the backend-legal Op"); + } + customOpInputOperands.push_back(customOpInputOperand); + } + SmallVector customOpResultTypes; + auto convertTypesResult = getTypeConverter()->convertTypes( + op->getResultTypes(), customOpResultTypes); + if (convertTypesResult.failed()) + return rewriter.notifyMatchFailure( + op, "failed to convert TOSA CustomOp result types; Only tensor " + "types are supported for the resutls."); + auto tosaCustomOp = rewriter.create( + op->getLoc(), TypeRange{customOpResultTypes}, + llvm::StringRef(op->getName().stripDialect()), + ValueRange{customOpInputOperands}); + // TODO: Use the 'config' StringAttr in the TOSA CustomOp interface + // instead. This requires the LLVM version uplift to include new TOSA + // CustomOp definition + tosaCustomOp.getOperation()->setAttr( + rewriter.getStringAttr("config"), + rewriter.getStringAttr("torch_mlir")); + rewriter.replaceOp(op, tosaCustomOp.getResults()); + return success(); + } + return failure(); + } +}; + +} // namespace + +// ----------------------------------------------------------------------------- +// TorchBackendLegalToTosaCustom Pass +// ----------------------------------------------------------------------------- + +namespace { +class ConvertTorchBackendLegalToTosaCustom + : public ConvertTorchBackendLegalToTosaCustomBase< + ConvertTorchBackendLegalToTosaCustom> { +public: + ConvertTorchBackendLegalToTosaCustom() = default; + ConvertTorchBackendLegalToTosaCustom(ArrayRef customOps) { + this->customOps = customOps; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context, + customOps); + + for (std::string opName : customOps) { + target.addIllegalOp(OperationName(opName, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchBackendLegalToTosaCustomPass( + ArrayRef customOps) { + return std::make_unique(customOps); +} diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 30054c13a4d77..6a45f90da9e2a 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() { "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", @@ -96,7 +96,9 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( } void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { + OpPassManager &pm, const TosaBackendPipelineOptions &options) { + pm.addNestedPass( + createConvertTorchBackendLegalToTosaCustomPass(options.customOps)); pm.addNestedPass(createConvertTorchToTosaPass()); // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); diff --git a/python/torch_mlir/backends/tosa/__init__.py b/python/torch_mlir/backends/tosa/__init__.py new file mode 100644 index 0000000000000..3eaae5ef38048 --- /dev/null +++ b/python/torch_mlir/backends/tosa/__init__.py @@ -0,0 +1,70 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import Optional, Sequence, Union + +import torch +import torch_mlir + +from ...compiler_utils import run_pipeline_with_repro_report + +def compile(model: torch.nn.Module, + example_args: Union[torch_mlir._example_arg, Sequence[torch_mlir._example_arg]], + use_tracing: bool = False, + ignore_traced_shapes = False, + backend_legal_ops: Optional[Sequence[str]] = None, + backend_custom_ops: Optional[Sequence[str]] = None, + verbose: bool = False): + """Convert a PyTorch model to TOSA IR in MLIR. + + Args: + model: The PyTorch model to convert. + example_args: A list of example arguments to use when inferring the + shapes of the arguments to `forward` method of the model. + A single tensor is treated as a list of a single tensor. + A TensorPlaceholder object is also allowed in the place of any + Tensor. + output_type: The kind of output to produce. See `OutputType` for more + details. + use_tracing: If True, use `torch.jit.trace` to convert the model to + JIT IR rather than `torch.jit.script`. + ignore_traced_shapes: If True, ignore the shapes that were observed + during tracing. This should only be used if one knows that the + original traced program would result in the same trace (modulo + shapes) for all shape combinations implied by any + `TensorPlaceholder`'s used as `example_args`. Also, + strictly-speaking, this option covers dtypes too, but we just say + "shapes" to be succinct. + backend_legal_ops: A list of ops that should be considered legal for + the backend. An op that is considered legal will not be decomposed. + This option is only valid with the `"torch"` output type. + backend_custom_ops: A list of ops to be converted to the custom ops in + the backend dialect. + verbose: If true, print extra information about the conversion. + + Returns: + An MLIR module that contains the converted model in TOSA IR. + """ + + if backend_legal_ops is None: + backend_legal_ops = torch_mlir.BACKEND_LEGAL_OPS.get(torch_mlir.OutputType.TOSA, []) + + module = torch_mlir.compile(model, example_args, torch_mlir.OutputType.TORCH, use_tracing, + ignore_traced_shapes, backend_legal_ops, verbose) + + if backend_custom_ops is None: + backend_custom_ops = [] + + backend_option_string = "{custom-ops=" + ",".join(backend_custom_ops) + "}" + run_pipeline_with_repro_report( + module, + f"builtin.module(torch-backend-to-tosa-backend-pipeline{backend_option_string})", + "Lowering Torch Backend IR -> TOSA Backend IR") + if verbose: + print("\n====================") + print("TOSA Backend IR") + print(module) + return module + diff --git a/test/Conversion/TorchToTosa/selective_decomposition.mlir b/test/Conversion/TorchToTosa/selective_decomposition.mlir new file mode 100644 index 0000000000000..0f863de840133 --- /dev/null +++ b/test/Conversion/TorchToTosa/selective_decomposition.mlir @@ -0,0 +1,19 @@ +// RUN: torch-mlir-opt <%s -convert-torch-backend-legal-to-tosa-custom="custom-ops=torch.aten._softmax" -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten._softmax( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,4],f32> -> tensor<4x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {config = "torch_mlir", identifier = "aten._softmax"} : (tensor<4x4xf32>, tensor, tensor) -> tensor<4x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x4xf32> -> !torch.vtensor<[4,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,4],f32> +// CHECK: } +func.func @torch.aten._softmax(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { + %int-1 = torch.constant.int -1 + %false = torch.constant.bool false + %0 = torch.aten._softmax %arg0, %int-1, %false : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +}