From 33eef15e428f848e3848d1038ed71faab893a686 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 30 Apr 2024 14:36:40 -0400 Subject: [PATCH] Support onnx.If (#2825) This is probably a decent PR for learning about blocks and regions. If you're here to learn about that, consider also looking at lib/Conversion/TorchToSCF/TorchToSCF.cpp While this doesn't include an e2e test, it is tested downstream in https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py --------- Co-authored-by: Xida Ren --- .../Conversion/TorchOnnxToTorch/Patterns.h | 25 +++++++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 54 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 12 ++--- .../test_suite/diagonal.py | 31 +++++++++++ python/torch_mlir/extras/onnx_importer.py | 8 ++- test/Conversion/TorchOnnxToTorch/ops/if.mlir | 20 +++++++ 6 files changed, 141 insertions(+), 9 deletions(-) create mode 100644 test/Conversion/TorchOnnxToTorch/ops/if.mlir diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index d3260500cfa8..3230cc8b46a0 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -97,6 +97,31 @@ struct OpBinder { return success(); } + ParseResult tensorResultTypes(llvm::SmallVector &typeList) { + for (auto result : op->getResults()) { + auto t = toValidTensorType(result.getType()); + if (!t) + return failure(); + typeList.push_back(t); + } + return success(); + } + + // The importer imports Onnx.GraphProto attributes as regions attached to the + // op. + ParseResult getRegionAtIndex(mlir::Region *®ion, int64_t idx) { + if (idx >= op->getNumRegions()) + return failure(); + + region = &op->getRegion(idx); + + if (region == nullptr) { + return failure(); + } + + return success(); + } + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { if (idx >= op->getNumResults()) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7a150794cb4b..1f1e2e5d7f0c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( alignCorners); return success(); }); + patterns.onOp( + "If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value conditionTensor; + if (binder.tensorOperand(conditionTensor)) { + return rewriter.notifyMatchFailure(binder.op, + "condition bind failure"); + } + + auto conditionType = + conditionTensor.getType().cast(); + if (!conditionType || conditionType.getSizes().size() != 1) + return rewriter.notifyMatchFailure( + binder.op, "condition must have one single element per " + "https://onnx.ai/onnx/operators/onnx__If.html"); + auto conditionInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + conditionTensor); + auto conditionBool = rewriter.create( + binder.getLoc(), rewriter.getType(), conditionInt); + + llvm::SmallVector resultTypes; + if (binder.tensorResultTypes(resultTypes)) { + return rewriter.notifyMatchFailure(binder.op, + "result type bind failure"); + } + + Region *thenRegion, *elseRegion; + if (binder.getRegionAtIndex(elseRegion, 0) || + binder.getRegionAtIndex(thenRegion, 1)) { + return rewriter.notifyMatchFailure(binder.op, "region bind failure"); + } + + auto primIfOp = rewriter.create( + binder.getLoc(), TypeRange(resultTypes), conditionBool); + + auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) { + rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin()); + }; + inlineIfCase(*thenRegion, primIfOp.getThenRegion()); + inlineIfCase(*elseRegion, primIfOp.getElseRegion()); + + auto replaceTerminator = [&](Region ®ion) { + PatternRewriter::InsertionGuard guard(rewriter); + Operation *terminator = region.front().getTerminator(); + rewriter.setInsertionPoint(terminator); + rewriter.replaceOpWithNewOp( + terminator, terminator->getOperands()); + }; + replaceTerminator(primIfOp.getThenRegion()); + replaceTerminator(primIfOp.getElseRegion()); + + rewriter.replaceOp(binder.op, primIfOp.getResults()); + return success(); + }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fcb7e053a0db..25d8fa9be5a2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2562,16 +2562,12 @@ "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", + # Failure - onnx_import # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", - # Failure - onnx_lowering: onnx.If - "DiagonalModule_basic", - "DiagonalModule_nonsquare", - "DiagonalModule_transposed", - "DiagonalModule_with_dims", - "DiagonalModule_with_dims_and_offset", - "DiagonalModule_with_negative_dims", - "DiagonalModule_with_offset", + # these diagonal modules are currently failing due to dynamic shape. + # We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead. + # when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here. "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", # Failure - onnx_lowering: onnx.MaxPool diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py index 6371f9a8d7a7..3bd3796dad8e 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils): # ============================================================================== +class DiagonalWithStaticShapeModule(torch.nn.Module): + """ + Diagonal with static shape. The other diagonal modules are failing in onnx + because DecomoposeAtenEyeMOp requires constants n, m, which are only constant + when the shape is static. + + Please remove this module and associated test once the issue is fixed. + """ + + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([5, 9], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.diagonal(a) + + +@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule()) +def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 9)) + + +# ============================================================================== + + class DiagonalTransposedModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 8d0e4cf5a8e1..e0d3529d942e 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -347,8 +347,14 @@ def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]): continue elif handler is False: # Active error. + # try matching attribute type ID to name for a more descriptive error message + try: + attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type) + except ValueError: + attr_type_name = "UNKNOWN" raise OnnxImportError( - f"ONNX importer does not support generic node attribute type {attr_type}. " + f"ONNX importer does not support generic node attribute type {attr_type_name} " + f"with ID {attr_type}. " f"This likely means that this is a special node which requires specific " f"handling in the importer: {onnx_attr}" ) diff --git a/test/Conversion/TorchOnnxToTorch/ops/if.mlir b/test/Conversion/TorchOnnxToTorch/ops/if.mlir new file mode 100644 index 000000000000..1d95a3f5fc3a --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/ops/if.mlir @@ -0,0 +1,20 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s + +// CHECK-LABEL: func.func @test_ifop_basic +// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>) +// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32> +// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32> +// CHECK-DAG: } else { +// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32> +// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32> +func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + %0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> { + %1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1 : !torch.vtensor<[1],f32> + }, { + %1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> + torch.operator_terminator %1 : !torch.vtensor<[1],f32> + } + return %0 : !torch.vtensor<[1],f32> +}