diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 28138edcb4787..f376ac0d881b9 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -111,7 +111,16 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { This pass assumes that TOSA ops are responsible for emitting error guards in case of shape mismatches. }]; - let constructor = "mlir::torch::createConvertTorchToTosaPass()"; + let options = [ + ListOption<"customOps", "custom-ops", "std::string", + "List of operation names that should be converted to tosa::custom", + "llvm::cl::ZeroOrMore"> + ]; + + let constructor = [{ + mlir::torch::createConvertTorchToTosaPass( + /*customOps=*/{}) + }]; } def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> { diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index a6d774a64db10..f7691b697bb08 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -16,7 +16,9 @@ namespace mlir { namespace torch { -std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass(); +std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTosaPass( ArrayRef<std::string> customOps ); } } // namespace mlir diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index fd350da1d61ec..c975ce3961763 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -28,7 +28,19 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// Creates a pipeline that lowers from the torch backend contract to the /// TOSA backend contract.
-void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); +struct TosaBackendPipelineOptions + : public PassPipelineOptions<TosaBackendPipelineOptions> { + // If this option is true, lower supported complex ops to tosa::custom ops. + // If this option is false, complex ops are decomposed instead. + Option<bool> custom{*this, "custom", + llvm::cl::desc("Lower complex operations to tosa::custom ops."), + llvm::cl::init(true)}; + ListOption<std::string> customOps{ + *this, "custom-ops", + llvm::cl::desc("List of ops to be converted to the backend."), + llvm::cl::ZeroOrMore}; +}; +void createTorchBackendToTosaBackendPipeline(OpPassManager &pm, const TosaBackendPipelineOptions &options); // Do not register the torch-to-mhlo pipeline if mhlo target is disabled #ifdef TORCH_MLIR_ENABLE_MHLO diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 5b23a58f09c2e..70908a32a5274 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -3042,7 +3042,7 @@ LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite( // %4 = "tosa.reciprocal"(%3) : (tensor<2x1xf32>) -> tensor<2x1xf32> // %5 = "tosa.mul"(%2, %4) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<2x1xf32>) -> tensor<2x3xf32> // No-Decompose TOSA format: without -torch-decompose-complex-ops flag - // "tosa.custom(%x){dim = 1 : i64, identifier = "softmax"}" : (tensor<2x3xf32>) -> tensor<2x3xf32> + // "tosa.custom(%x){identifier = "softmax"}" : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> // Check AtenSoftmaxIntOp first input is a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>(); @@ -3059,21 +3059,15 @@ LogicalResult ConvertAtenOp<AtenSoftmaxIntOp>::matchAndRewrite( // Create output type for tosa::CustomOp input auto outType = getTypeConverter()->convertType(op.getType()); - - // Create attributes for tosa::CustomOp input - // example: {dim = 1 : i64, identifier = "softmax"} - StringAttr nameIdAttr= rewriter.getStringAttr("identifier"); + // Create name attribute and multi-args for tosa::CustomOp input StringAttr nameValueAttr= rewriter.getStringAttr("softmax"); - StringAttr dimIdAttr= rewriter.getStringAttr("dim"); - IntegerAttr dimValueAttr = rewriter.getI64IntegerAttr(dim); - mlir::NamedAttribute nameAttr = mlir::NamedAttribute(nameIdAttr, nameValueAttr); - mlir::NamedAttribute dimAttr = mlir::NamedAttribute(dimIdAttr, dimValueAttr); - llvm::ArrayRef<mlir::NamedAttribute> custom_attributes{nameAttr, dimAttr}; + auto dimTensor = tosa::getConstTensor<int64_t>( rewriter, op.getOperation(), dim, {1}); + SmallVector<Value> inputOperands{adaptor.self(), dimTensor.value()}; // TODO unportable target hardware implementation of exp(%x) / sum(exp(%x), %dim) - rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, adaptor.self(), - custom_attributes); - + rewriter.replaceOpWithNewOp<tosa::CustomOp>(op, outType, nameValueAttr, + inputOperands); return success(); } @@ -3708,6 +3702,10 @@ class ConvertAtenCloneOp : public OpConversionPattern<AtenCloneOp> { namespace { class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> { public: + ConvertTorchToTosa() = default; + ConvertTorchToTosa(ArrayRef<std::string> customOps) { + this->customOps = customOps; + } void getDependentDialects(DialectRegistry &registry) const override { registry.insert<tosa::TosaDialect>(); registry.insert<tensor::TensorDialect>(); @@ -3904,7 +3902,9 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> { INSERT_ATENOP_PATTERN(AtenSliceTensorOp); INSERT_ATENOP_PATTERN(AtenBroadcastToOp); INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); + if (llvm::is_contained(customOps, "torch.aten.softmax.int")) { + INSERT_ATENOP_PATTERN(AtenSoftmaxIntOp); + }
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); INSERT_ATENOP_PATTERN(ValsemVariantAtenCopyOp); @@ -3925,6 +3925,6 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> { } // namespace std::unique_ptr<OperationPass<func::FuncOp>> -mlir::torch::createConvertTorchToTosaPass() { - return std::make_unique<ConvertTorchToTosa>(); +mlir::torch::createConvertTorchToTosaPass(ArrayRef<std::string> customOps) { + return std::make_unique<ConvertTorchToTosa>(customOps); } diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 3e5b389697078..118dd346c21e9 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -47,7 +47,7 @@ void mlir::torch::registerTorchConversionPasses() { "contract.", TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline); - mlir::PassPipelineRegistration<>( + mlir::PassPipelineRegistration<TosaBackendPipelineOptions>( "torch-backend-to-tosa-backend-pipeline", "Pipeline lowering torch backend contract to TOSA backend " "contract.", @@ -96,8 +96,8 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( } void TorchConversion::createTorchBackendToTosaBackendPipeline( - OpPassManager &pm) { - pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass()); + OpPassManager &pm, const TosaBackendPipelineOptions &options) { + pm.addNestedPass<func::FuncOp>(createConvertTorchToTosaPass(options.customOps)); // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass()); diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 7e0a947f48f24..a762a37198400 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -929,21 +929,4 @@ func.func @torch.aten.to.dtype(%arg0: !torch.vtensor<[3,5],si64>) -> !torch.vten func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !torch.vtensor<[1,12,5,5],f32>, %arg2: !torch.vtensor<[],f32>) ->
!torch.vtensor<[1,12,5,5],f32> { %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> return %0 : !torch.vtensor<[1,12,5,5],f32> -} - -// ----- -// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.none -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "tosa.custom"(%[[VAL_1]]) {dim = 1 : i64, identifier = "softmax"} : (tensor<2x3xf32>) -> tensor<2x3xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[2,3],f32> -// CHECK: } -func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { - %none = torch.constant.none - %dim = torch.constant.int 1 - %ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32> - return %ret : !torch.vtensor<[2,3],f32> } \ No newline at end of file diff --git a/test/Conversion/TorchToTosa/custom.mlir b/test/Conversion/TorchToTosa/custom.mlir new file mode 100644 index 0000000000000..77e84767eab40 --- /dev/null +++ b/test/Conversion/TorchToTosa/custom.mlir @@ -0,0 +1,19 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa="custom-ops=torch.aten.softmax.int" -split-input-file -verify-diagnostics | FileCheck %s + +// ----- +// CHECK-LABEL: func.func @torch.aten.softmax.int$cst_dim( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.none +// CHECK: 
%[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> +// CHECK: %[[VAL_5:.*]] = "tosa.custom"(%[[VAL_1]], %[[VAL_4]]) {identifier = "softmax"} : (tensor<2x3xf32>, tensor<1xi64>) -> tensor<2x3xf32> +// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> +// CHECK: return %[[VAL_6]] : !torch.vtensor<[2,3],f32> +// CHECK: } +func.func @torch.aten.softmax.int$cst_dim(%t: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> { + %none = torch.constant.none + %dim = torch.constant.int 1 + %ret = torch.aten.softmax.int %t, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32> + return %ret : !torch.vtensor<[2,3],f32> +}