diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td
index 28138edcb4787..39003e704f16c 100644
--- a/include/torch-mlir/Conversion/Passes.td
+++ b/include/torch-mlir/Conversion/Passes.td
@@ -114,6 +114,38 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
   let constructor = "mlir::torch::createConvertTorchToTosaPass()";
 }
 
+def ConvertTorchToTosaCustom : Pass<"convert-torch-to-tosa-custom", "func::FuncOp"> {
+  let summary = "Convert Torch ops to TOSA custom ops";
+  let description = [{
+    The purpose of tosa::CustomOp is to handle complex ops that we do not
+    want to decompose into simpler ops.
+    The ATen op name is used to construct a StringAttr that becomes the
+    identifier attribute of the tosa::CustomOp.
+    Each input argument from the Torch dialect is converted to a tensor of
+    numeric values that can serve as an operand of the tosa::CustomOp. After
+    conversion, a ValueRange/SmallVector gathers all converted values into
+    the final input operands of the tosa::CustomOp.
+
+    Take softmax as an example:
+      "aten.softmax.int"(%x, %dim) : (tensor<2x3xf32>, int) -> tensor<2x3xf32>
+    Decomposed (with the -torch-decompose-complex-ops flag):
+      %2 = "tosa.exp"(%x) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+      %3 = "tosa.reduce_sum"(%2) {axis = %dim : i64} : (tensor<2x3xf32>) -> tensor<2x1xf32>
+      %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32>
+      %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
+    Not decomposed (with convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int"):
+      "tosa.custom"(%x, %dim_tensor) {identifier = "torch.aten.softmax.int"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32>
+  }];
+  let options = [
+    ListOption<"customOps", "custom-ops", "std::string",
+               "List of operation names that should be converted to tosa::CustomOp",
+               "llvm::cl::ZeroOrMore">
+  ];
+
+  let constructor = [{
+    mlir::torch::createConvertTorchToTosaCustomPass(
+        /*customOps=*/{})
+  }];
+}
+
 def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
   let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
   let description = [{
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
index a6d774a64db10..0549cfddd10fd 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h
@@ -17,6 +17,9 @@
 namespace mlir {
 namespace torch {
 std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass();
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createConvertTorchToTosaCustomPass(ArrayRef<std::string> customOps);
 }
 } // namespace mlir
diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
index fd350da1d61ec..5d8fbb71c63b5 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -28,7 +28,15 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm);
 
 /// Creates a pipeline that lowers from the torch backend contract to the
 /// TOSA backend contract.
-void createTorchBackendToTosaBackendPipeline(OpPassManager &pm);
+struct TosaBackendPipelineOptions
+    : public PassPipelineOptions<TosaBackendPipelineOptions> {
+  ListOption<std::string> customOps{
+      *this, "custom-ops",
+      llvm::cl::desc("List of ops to be converted to tosa::CustomOp."),
+      llvm::cl::ZeroOrMore};
+};
+void createTorchBackendToTosaBackendPipeline(
+    OpPassManager &pm, const TosaBackendPipelineOptions &options);
 
 // Do not register the torch-to-mhlo pipeline if mhlo target is disabled
 #ifdef TORCH_MLIR_ENABLE_MHLO
diff --git a/lib/Conversion/TorchToTosa/CMakeLists.txt b/lib/Conversion/TorchToTosa/CMakeLists.txt
index e1f5142bd9242..f2bb518a9aad3 100644
--- a/lib/Conversion/TorchToTosa/CMakeLists.txt
+++ b/lib/Conversion/TorchToTosa/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(TorchMLIRTorchToTosa
   TorchToTosa.cpp
+  TorchToTosaCustom.cpp
   TosaLegalizeUtils.cpp
   TosaLegalizeCommon.cpp
diff --git a/lib/Conversion/TorchToTosa/TorchToTosaCustom.cpp b/lib/Conversion/TorchToTosa/TorchToTosaCustom.cpp
new file mode 100644
index 0000000000000..5b0eddcd08d8d
--- /dev/null
+++ b/lib/Conversion/TorchToTosa/TorchToTosaCustom.cpp
@@ -0,0 +1,150 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
+#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
+#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include <unordered_map>
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+template <typename AtenOpT>
+class ConvertSelectiveAtenOpToTosaCustom : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange adaptor_operands = adaptor.getOperands();
+    int num_operands = adaptor_operands.size();
+    std::vector<mlir::Value> inputs_vec;
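+    // Convert each operand of the ATen op into a form tosa::CustomOp can
+    // consume: scalar int/float constants become rank-1 TOSA const tensors,
+    // while tensor operands are passed through after type conversion.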
+    for (int i = 0; i < num_operands; i++) {
+      auto operand = *op.getODSOperands(i).begin();
+      auto adaptor_operand_type = adaptor_operands[i].getType();
+      if (adaptor_operand_type.isa<mlir::IntegerType>()) {
+        // Produced by a Torch::ConstantIntOp.
+        int64_t operand_tosa;
+        if (!matchPattern(operand, m_TorchConstantInt(&operand_tosa)))
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: operand should be a torch.constant.int");
+        auto operand_tensor_int = tosa::getConstTensor<int64_t>(
+            rewriter, op.getOperation(), operand_tosa, {1});
+        inputs_vec.push_back(operand_tensor_int.value());
+      } else if (adaptor_operand_type.isa<mlir::FloatType>()) {
+        // Produced by a Torch::ConstantFloatOp.
+        double operand_tosa;
+        if (!matchPattern(operand, m_TorchConstantFloat(&operand_tosa)))
+          return rewriter.notifyMatchFailure(
+              op, "unimplemented: operand should be a torch.constant.float");
+        auto operand_tensor_float = tosa::getConstTensor<float>(
+            rewriter, op.getOperation(), operand_tosa, {1});
+        inputs_vec.push_back(operand_tensor_float.value());
+      } else if (adaptor_operand_type.isa<mlir::TensorType>()) {
+        // A Torch::ValueTensorType, already converted to a builtin tensor.
+        inputs_vec.push_back(*adaptor.getODSOperands(i).begin());
+      } else {
+        // TODO: Handle more types like !torch.list<...>, !torch.device,
+        // !torch.string, !torch.none, !torch.generator.
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: input type; the input has to be an int, a "
+                "float, or a tensor");
+      }
+    }
+    // Create the result type for the tosa::CustomOp.
+    auto outType = this->getTypeConverter()->convertType(op.getType());
+
+    // Gather the converted values into the operand list of the tosa::CustomOp.
+    llvm::ArrayRef<mlir::Value> ref(inputs_vec.data(), inputs_vec.size());
+    ValueRange custom_inputs(ref);
+    rewriter.replaceOpWithNewOp<tosa::CustomOp>(
+        op, outType, op.getOperationName(), custom_inputs);
+    return success();
+  }
+};
+
+} // namespace
+
+// -----------------------------------------------------------------------------
+// TorchToTosaCustom Pass
+// -----------------------------------------------------------------------------
+
+namespace {
+class ConvertTorchToTosaCustom
+    : public ConvertTorchToTosaCustomBase<ConvertTorchToTosaCustom> {
+public:
+  ConvertTorchToTosaCustom() = default;
+  ConvertTorchToTosaCustom(ArrayRef<std::string> customOps) {
+    this->customOps = customOps;
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<tosa::TosaDialect>();
+    registry.insert<tensor::TensorDialect>();
+    registry.insert<arith::ArithDialect>();
+    TorchConversion::getBackendTypeConversionDependentDialects(registry);
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<tosa::TosaDialect, tensor::TensorDialect,
+                           arith::ArithDialect>();
+
+    TypeConverter typeConverter;
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+
+    RewritePatternSet patterns(context);
+
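+    // Only ops named in the custom-ops option are marked illegal and given a
+    // conversion pattern; everything else is left to the regular TorchToTosa
+    // lowering that runs later in the pipeline.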
+    std::unordered_map<std::string, bool> customOpsMap;
+    for (auto key : customOps) {
+      customOpsMap[key] = true;
+    }
+
+#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertSelectiveAtenOpToTosaCustom<AtenOp>>(typeConverter,      \
+                                                           context);
+    if (customOpsMap["torch.aten.softmax.int"]) {
+      INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp);
+    }
+    if (customOpsMap["torch.aten.rsub.Scalar"]) {
+      INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
+    }
+#undef INSERT_ATENOP_PATTERN
+
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::createConvertTorchToTosaCustomPass(
+    ArrayRef<std::string> customOps) {
+  return std::make_unique<ConvertTorchToTosaCustom>(customOps);
+}
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 3e5b389697078..5bab722b11997 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() {
       "contract.",
       TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline);
-  mlir::PassPipelineRegistration<>(
+  mlir::PassPipelineRegistration<TorchConversion::TosaBackendPipelineOptions>(
       "torch-backend-to-tosa-backend-pipeline",
       "Pipeline lowering torch backend contract to TOSA backend "
       "contract.",
       TorchConversion::createTorchBackendToTosaBackendPipeline);
@@ -96,7 +96,11 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
 }
 
 void TorchConversion::createTorchBackendToTosaBackendPipeline(
-    OpPassManager &pm) {
+    OpPassManager &pm, const TosaBackendPipelineOptions &options) {
+  if (!options.customOps.empty()) {
+    pm.addNestedPass<func::FuncOp>(
+        createConvertTorchToTosaCustomPass(options.customOps));
+  }
   pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass());
   // Perform rank broadcasting so TosaToLinalg pass works
   pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
diff --git a/test/Conversion/TorchToTosa/custom.mlir b/test/Conversion/TorchToTosa/custom.mlir
new file mode 100644
index 0000000000000..2d221b2abf203
--- /dev/null
+++ b/test/Conversion/TorchToTosa/custom.mlir
@@ -0,0 +1,39 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int,torch.aten.rsub.Scalar" -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
+// CHECK:         %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK:         %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK:         %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// CHECK:         %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// CHECK:         %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.softmax.int"} : (tensor<2x3xf32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2x3xf32>
+// CHECK:         %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
+// CHECK:         return %[[VAL_7]] : !torch.vtensor<[2,3],f32>
+// CHECK:       }
+func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> {
+  %dtype = torch.constant.int 1
+  %dim = torch.constant.int 1
+  %ret = torch.aten.softmax.int %t, %dim, %dtype : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
+  return %ret : !torch.vtensor<[2,3],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.rsub.Scalar$basic(
+// CHECK-SAME:    %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:         %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:         %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK:         %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK:         %[[VAL_4:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<1xf32>} : () -> tensor<1xf32>
+// CHECK:         %[[VAL_5:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
+// CHECK:         %[[VAL_6:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]], %[[VAL_5]]) {identifier = "torch.aten.rsub.Scalar"} : (tensor<?x?xf32>, tensor<1xf32>, tensor<1xi64>) -> tensor<?x?xf32>
+// CHECK:         %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:         return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
+// CHECK:       }
+func.func @torch.aten.rsub.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %other = torch.constant.float 3.123400e+00
+  %alpha = torch.constant.int 1
+  %0 = torch.aten.rsub.Scalar %arg0, %other, %alpha : !torch.vtensor<[?,?],f32>, !torch.float, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
\ No newline at end of file
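
For reference, a rough usage sketch. The input file name is hypothetical, and the
direct pipeline flag assumes MLIR's standard handling of registered pipeline options:

  # Run only the new conversion pass, as the lit test above does:
  torch-mlir-opt -convert-torch-to-tosa-custom="custom-ops=torch.aten.softmax.int,torch.aten.rsub.Scalar" input.mlir

  # Run the full TOSA backend pipeline with the custom-ops option threaded through:
  torch-mlir-opt -torch-backend-to-tosa-backend-pipeline="custom-ops=torch.aten.softmax.int" input.mlir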