From 5ecd1d31d305d2f137f7fc72eb06a1d4e26e4c43 Mon Sep 17 00:00:00 2001 From: Siavash Nazari Date: Sun, 6 Nov 2022 21:34:22 -0500 Subject: [PATCH] Add support to lower backend-legal operations as TOSA CustomOps --- include/torch-mlir/Conversion/Passes.td | 51 +++++ .../Conversion/TorchToTosa/TorchToTosa.h | 5 +- .../TorchConversion/Transforms/Passes.h | 11 +- lib/Conversion/TorchToTosa/CMakeLists.txt | 1 + .../TorchBackendLegalToTosaCustom.cpp | 194 ++++++++++++++++++ .../TorchConversion/Transforms/Passes.cpp | 6 +- python/torch_mlir/__init__.py | 5 +- .../TorchToTosa/selective_decomposition.mlir | 19 ++ 8 files changed, 286 insertions(+), 6 deletions(-) create mode 100644 lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp create mode 100644 test/Conversion/TorchToTosa/selective_decomposition.mlir diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 28138edcb4787..00f486ac0cb7f 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,6 +105,57 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchBackendLegalToTosaCustom + : Pass<"convert-torch-backend-legal-to-tosa-custom", "func::FuncOp"> { + let summary = "Convert Torch backend-legal operations to TOSA CustomOps."; + let description = [{ + This pass extends the selective decomposition patterns to TOSA backend, + such that it will convert operations marked as backend-legal by the user + for the TOSA backend to CustomOps in TOSA dialect. + The backend-legal operations will not be decomposed or lowered through + the existing Torch to TOSA conversion patterns, even if such patterns + exist in the subsequent Torch to TOSA conversion pass; instead a TOSA + CustomOp will be created with the specifications below: + 1. The 'identifier' attribute in the CustomOp will be the operation + name of the corresponding ATen operation. + 2. All inputs to the ATen operation will be converted into legal TOSA + operations. There will be no distiction between inputs and constant + attributes; all will be treated as inputs to the CustomOp. It is up + to the consuming backend to deduce them based on the ATen operation + semantics. + 3. Since TOSA conversion pattern of Torch operations are responsible + to consume and convet the constant attributes (such as 'axis' for a + reduction operation), the TOSA CustomOp conversion will also match + and rewrite these attributes as TOSA ConstOps as well. Specifically: + - torch.constant.bool -> tosa.ConstOp (of tensor type i1) + - torch.constant.int -> tosa.ConstOp (of tensor type i64) + - torch.constant.float -> tosa.ConstOp (of tensor type f32) + - torch.prim.ListConstruct -> tosa.ConstOp (of tensor type i64) + - torch.constant.none -> tosa.CustomOp (of tensor type i1) + The 'identifier' attribute of this + CustomOp is 'torch.constant.none' + All other Torch ATen operations will be lowered to TOSA by the Torch + to TOSA conversion pass after this one. + TODO: Extend the contract for other Torch constant operations such as: + - torch.constant.device + - torch.constant.number + - torch.constant.str + 4. The order of the input operands of the backend-legal Torch operation + preserved and the TOSA CustomOp will have the same order. + TODO: Update this pass to populate the 'config' attribute of the TOSA + CustomOp to establish a proper operand mapping scheme between + the backend-legal Torch operation and the TOSA CustomOp. + (See the commit at https://github.com/llvm/llvm-project/commit/d94ee70f4f01e4d9eec49e02eff57a5655618401) + }]; + let options = [ListOption< + "customOps", "custom-ops", "std::string", + "List of operations considered backend-legal and should be converted " + "to converted to CustomOps in TOSA dialect.", + "llvm::cl::ZeroOrMore">]; + let constructor = + "mlir::torch::createConvertTorchBackendLegalToTosaCustomPass(/*customOps=*/{})"; +} + def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db10..461979dc896ae 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -16,8 +16,11 @@ namespace mlir { namespace torch { +std::unique_ptr> +createConvertTorchBackendLegalToTosaCustomPass(ArrayRef customOps); + std::unique_ptr> createConvertTorchToTosaPass(); -} +} // namespace torch } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index fd350da1d61ec..bd40a11642140 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -28,7 +28,16 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); +struct TosaBackendPipelineOptions + : public PassPipelineOptions { + ListOption backendLegalOps{ + *this, "custom-ops", + llvm::cl::desc("List of operations considered backend-legal and " + "should be converted to CustomOps in TOSA dialect.")}; +}; + +void createTorchBackendToTosaBackendPipeline( + OpPassManager &pm, const TosaBackendPipelineOptions &options); // Do not register the torch-to-mhlo pipeline if mhlo target is disabled #ifdef TORCH_MLIR_ENABLE_MHLO diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt index e1f5142bd9242..d83322b9cdbf1 100644 --- a/lib/Conversion/TorchToTosa/CMakeLists.txt +++ b/lib/Conversion/TorchToTosa/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_conversion_library(TorchMLIRTorchToTosa + TorchBackendLegalToTosaCustom.cpp TorchToTosa.cpp TosaLegalizeUtils.cpp TosaLegalizeCommon.cpp diff --git a/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp b/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp new file mode 100644 index 0000000000000..85509f6c22e0f --- /dev/null +++ b/lib/Conversion/TorchToTosa/TorchBackendLegalToTosaCustom.cpp @@ -0,0 +1,194 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertBackendLegalAtenOp : public ConversionPattern { +public: + ArrayRef customOps; + + ConvertBackendLegalAtenOp(TypeConverter &typeConverter, MLIRContext *context, + ArrayRef customOps) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context) { + this->customOps = customOps; + } + + Value matchConstantAttributeOp(Value inputOperand, PatternRewriter &rewriter, + Operation *backendLegalOp) const { + // Get the Torch Op to find the constant attributes attached to the + // backendLegalOp + Value torchInputOperand = inputOperand; + if (auto unrealizedCastOp = dyn_cast_or_null( + inputOperand.getDefiningOp())) { + torchInputOperand = unrealizedCastOp.getInputs()[0]; + } + // Handle the special case where input operand is an argument to the module + // function + if (!torchInputOperand.getDefiningOp()) + return inputOperand; + + return TypeSwitch(torchInputOperand.getDefiningOp()) + .Case([&](Operation *boolOperand) -> Value { + bool boolConstAttr; + if (matchPattern(boolOperand, m_TorchConstantBool(&boolConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + boolConstAttr, {}) + .value(); + } + return nullptr; + }) + .Case([&](Operation *intOperand) -> Value { + int64_t intConstAttr; + if (matchPattern(intOperand, m_TorchConstantInt(&intConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + intConstAttr, {}) + .value(); + } + return nullptr; + }) + .Case([&](Operation *floatOperand) -> Value { + double floatConstAttr; + if (matchPattern(floatOperand, + m_TorchConstantFloat(&floatConstAttr))) { + return tosa::getConstTensor(rewriter, backendLegalOp, + floatConstAttr, {}) + .value(); + } + return nullptr; + }) + .Case( + [&](Operation *intListConstructOperand) -> Value { + SmallVector intConstListAttr; + if (matchPattern(intListConstructOperand, + m_TorchConstantIntList(intConstListAttr))) { + return tosa::getConstTensor( + rewriter, backendLegalOp, intConstListAttr, + {static_cast(intConstListAttr.size())}) + .value(); + } + return nullptr; + }) + .Case([&](Operation *noneOperand) -> Value { + auto noneCustomOp = rewriter.create( + backendLegalOp->getLoc(), + RankedTensorType::get({}, rewriter.getIntegerType(1)), + rewriter.getStringAttr("torch.none"), ValueRange{}); + return noneCustomOp.getResult(0); + }) + .Default([&](Operation *defaultOperand) { return inputOperand; }); + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + + for (auto customOpName : customOps) { + if (customOpName == op->getName().getStringRef()) { + SmallVector customOpInputOperands; + + for (auto operand : operands) { + Value customOpInputOperand = + matchConstantAttributeOp(operand, rewriter, op); + if (!customOpInputOperand) { + return rewriter.notifyMatchFailure( + op, + "failed to match the constant operand of the backend-legal Op"); + } + customOpInputOperands.push_back(customOpInputOperand); + } + SmallVector customOpResultTypes; + auto convertTypesResult = getTypeConverter()->convertTypes( + op->getResultTypes(), customOpResultTypes); + if (convertTypesResult.failed()) + return failure(); + rewriter.replaceOpWithNewOp( + op, TypeRange{customOpResultTypes}, llvm::StringRef(customOpName), + ValueRange{customOpInputOperands}); + + return success(); + } + } + return failure(); + } +}; + +} // namespace + +// ----------------------------------------------------------------------------- +// TorchBackendLegalToTosaCustom Pass +// ----------------------------------------------------------------------------- + +namespace { +class ConvertTorchBackendLegalToTosaCustom + : public ConvertTorchBackendLegalToTosaCustomBase< + ConvertTorchBackendLegalToTosaCustom> { +public: + ConvertTorchBackendLegalToTosaCustom() = default; + ConvertTorchBackendLegalToTosaCustom(ArrayRef customOps) { + this->customOps = customOps; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context, customOps); + + for (std::string opName : customOps) { + // target.addLegalOp(OperationName(opName, context)); + } + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchBackendLegalToTosaCustomPass( + ArrayRef customOps) { + return std::make_unique(customOps); +} diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 3e5b389697078..b34e5a5241b22 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() { "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", @@ -96,7 +96,9 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( } void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { + OpPassManager &pm, const TosaBackendPipelineOptions &options) { + pm.addNestedPass( + createConvertTorchBackendLegalToTosaCustomPass(options.backendLegalOps)); pm.addNestedPass(createConvertTorchToTosaPass()); // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 93239e7ac9221..62116de99af07 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -192,7 +192,7 @@ def compile(model: torch.nn.Module, # tensor to a list of a single tensor to make the API more ergonomic. if isinstance(example_args, (torch.Tensor, TensorPlaceholder)): example_args = (example_args,) - + # If users passed in anything other than tensors or a list of tensors (e.g. # a dictionary), we can't handle it. if not isinstance(example_args, Sequence): @@ -285,10 +285,11 @@ def compile(model: torch.nn.Module, if output_type == OutputType.TORCH: return mb.module + option_string = "{custom-ops=" + ",".join(backend_legal_ops) + "}" if output_type == OutputType.TOSA: run_pipeline_with_repro_report( mb.module, - "torch-backend-to-tosa-backend-pipeline", + f"torch-backend-to-tosa-backend-pipeline{option_string}", "Lowering Torch Backend IR -> TOSA Backend IR") if verbose: print("\n====================") diff --git a/test/Conversion/TorchToTosa/selective_decomposition.mlir b/test/Conversion/TorchToTosa/selective_decomposition.mlir new file mode 100644 index 0000000000000..f1bedfcd81c4e --- /dev/null +++ b/test/Conversion/TorchToTosa/selective_decomposition.mlir @@ -0,0 +1,19 @@ +// RUN: torch-mlir-opt <%s -convert-torch-backend-legal-to-tosa-custom="custom-ops=torch.aten._softmax" -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @torch.aten._softmax( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,4],f32> -> tensor<4x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int -1 +// CHECK: %[[VAL_3:.*]] = torch.constant.bool false +// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<-1> : tensor} : () -> tensor +// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor} : () -> tensor +// CHECK: %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten._softmax"} : (tensor<4x4xf32>, tensor, tensor) -> tensor<4x4xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4x4xf32> -> !torch.vtensor<[4,4],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[4,4],f32> +// CHECK: } +func.func @torch.aten._softmax(%arg0: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> { + %int-1 = torch.constant.int -1 + %false = torch.constant.bool false + %0 = torch.aten._softmax %arg0, %int-1, %false : !torch.vtensor<[4,4],f32>, !torch.int, !torch.bool -> !torch.vtensor<[4,4],f32> + return %0 : !torch.vtensor<[4,4],f32> +}