From c510133643167bb5a043c671648a3d26ba532f13 Mon Sep 17 00:00:00 2001 From: AmosLewis Date: Tue, 25 Oct 2022 11:45:03 -0700 Subject: [PATCH] [MLIR][TORCH] Add Pass for custom softmax ops 1. Add an option "customOps" for convert-torch-to-tosa 2. Enable multi input arg for tosa::custom instead use multi attributes --- include/torch-mlir/Conversion/Passes.td | 11 ++++++- .../Conversion/TorchToTosa/TorchToTosa.h | 4 ++- .../TorchConversion/Transforms/Passes.h | 14 +++++++- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 32 +++++++++---------- .../TorchConversion/Transforms/Passes.cpp | 6 ++-- test/Conversion/TorchToTosa/custom.mlir | 19 +++++++++++ 6 files changed, 64 insertions(+), 22 deletions(-) create mode 100644 test/Conversion/TorchToTosa/custom.mlir diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 28138edcb4787..f376ac0d881b9 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -111,7 +111,16 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { This pass assumes that TOSA ops are responsible for emitting error guards in case of shape mismatches. }]; - let constructor = "mlir::torch::createConvertTorchToTosaPass()"; + let options = [ + ListOption<"customOps", "custom-ops", "std::string", + "List of operation names that should be converted to tosa::custom", + "llvm::cl::ZeroOrMore"> + ]; + + let constructor = [{ + mlir::torch::createConvertTorchToTosaPass( + /*customOps=*/{}) + }]; } def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db10..f7691b697bb08 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -16,7 +16,9 @@ namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToTosaPass(); +std::unique_ptr> createConvertTorchToTosaPass( + ArrayRef customOps + ); } } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index fd350da1d61ec..80458a71f45b0 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -28,7 +28,19 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract. -void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); +struct TosaBackendPipelineOptions + : public PassPipelineOptions { + // If this option is true, custom complex operations. + // If this option is false, skip decomposition of complex operations. + Option enableCustomOps{*this, "enableCustomOps", + llvm::cl::desc("Enable custom complex operations."), + llvm::cl::init(true)}; + ListOption customOps{ + *this, "custom-ops", + llvm::cl::desc("List of ops to be converted to the backend."), + llvm::cl::ZeroOrMore}; +}; +void createTorchBackendToTosaBackendPipeline(OpPassManager &pm, const TosaBackendPipelineOptions &options); // Do not register the torch-to-mhlo pipeline if mhlo target is disabled #ifdef TORCH_MLIR_ENABLE_MHLO diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1e92e25fabdd9..6dd507e631f0f 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3042,7 +3042,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32> // %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32> // No-Decompose TOSA format: without -torch-decompose-complex-ops flag - // "tosa.custom(%x){dim = 1 : i64, identifier = "softmax"}" : (tensor<2x3xf32>) -> tensor<2x3xf32> + // "tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> // Check AtenSoftmaxIntOp first input is a tensor type. auto selfType = adaptor.self().getType().dyn_cast(); @@ -3059,21 +3059,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Create output type for tosa::CustomOp input auto outType = getTypeConverter()->convertType(op.getType()); - - // Create attributes for tosa::CustomOp input - // example: {dim = 1 : i64, identifier = "softmax"} - StringAttr nameIdAttr= rewriter.getStringAttr("identifier"); + // Create name attribute and multi-args for tosa::CustomOp input StringAttr nameValueAttr= rewriter.getStringAttr("softmax"); - StringAttr dimIdAttr= rewriter.getStringAttr("dim"); - IntegerAttr dimValueAttr = rewriter.getI64IntegerAttr(dim); - mlir::NamedAttribute nameAttr = mlir::NamedAttribute(nameIdAttr, nameValueAttr); - mlir::NamedAttribute dimAttr = mlir::NamedAttribute(dimIdAttr, dimValueAttr); - llvm::ArrayRef custom_attributes{nameAttr, dimAttr}; + auto dimTensor = tosa::getConstTensor( + rewriter, op.getOperation(), dim, {1}); + SmallVector inputOperands{adaptor.self(), dimTensor.value()}; // TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim) - rewriter.replaceOpWithNewOp(op, outType, adaptor.self(), - custom_attributes); - + rewriter.replaceOpWithNewOp(op, outType, nameValueAttr, + inputOperands); return success(); } @@ -3708,6 +3702,10 @@ class ConvertAtenCloneOp : public OpConversionPattern { namespace { class ConvertTorchToTosa : public ConvertTorchToTosaBase { public: + ConvertTorchToTosa() = default; + ConvertTorchToTosa(ArrayRef customOps) { + this->customOps = customOps; + } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -3904,7 +3902,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); + if(!customOps.empty()){ + INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); + } INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(AtenCopyOp); @@ -3925,6 +3925,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToTosaPass() { - return std::make_unique(); +mlir::torch::createConvertTorchToTosaPass(ArrayRef customOps) { + return std::make_unique(customOps); } diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 3e5b389697078..118dd346c21e9 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() { "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", @@ -96,8 +96,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( } void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { - pm.addNestedPass(createConvertTorchToTosaPass()); + OpPassManager &pm, const TosaBackendPipelineOptions &options) { + pm.addNestedPass(createConvertTorchToTosaPass(options.customOps)); // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); diff --git a/test/Conversion/TorchToTosa/custom.mlir b/test/Conversion/TorchToTosa/custom.mlir new file mode 100644 index 0000000000000..77e84767eab40 --- /dev/null +++ b/test/Conversion/TorchToTosa/custom.mlir @@ -0,0 +1,19 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa="custom-ops=torch.aten.softmax.int" -split-input-file -verify-diagnostics | FileCheck %s + +// ----- +// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_5:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]]) {identifier = "softmax"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,3],f32> +// CHECK: } +func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { + %none = torch.constant.none + %dim = torch.constant.int 1 + %ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32> + return %ret : !torch.vtensor<[2,3],f32> +} \ No newline at end of file