Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support onnx.If #2825

Merged
merged 17 commits into from
Apr 30, 2024
Merged
25 changes: 25 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,31 @@ struct OpBinder {
return success();
}

ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto result : op->getResults()) {
auto t = toValidTensorType(result.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}

// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
if (idx >= op->getNumRegions())
return failure();

region = &op->getRegion(idx);

if (region == nullptr) {
renxida marked this conversation as resolved.
Show resolved Hide resolved
return failure();
}

return success();
}

ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) {
if (idx >= op->getNumResults())
Expand Down
54 changes: 54 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
alignCorners);
return success();
});
patterns.onOp(
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value conditionTensor;
if (binder.tensorOperand(conditionTensor)) {
return rewriter.notifyMatchFailure(binder.op,
"condition bind failure");
}

auto conditionType =
conditionTensor.getType().cast<Torch::ValueTensorType>();
if (!conditionType || conditionType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html");
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);

llvm::SmallVector<mlir::Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}

Region *thenRegion, *elseRegion;
if (binder.getRegionAtIndex(elseRegion, 0) ||
binder.getRegionAtIndex(thenRegion, 1)) {
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
}

auto primIfOp = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), TypeRange(resultTypes), conditionBool);

auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
};
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());

auto replaceTerminator = [&](Region &region) {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());

rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
});
patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
12 changes: 4 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,16 +2562,12 @@
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# Failure - onnx_lowering: onnx.If
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
# these diagonal modules are currently failing due to dynamic shape.
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
# Failure - onnx_lowering: onnx.MaxPool
Expand Down
31 changes: 31 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ==============================================================================


class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
when the shape is static.

Please remove this module and associated test once the issue is fixed.
"""

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)


@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))


# ==============================================================================


class DiagonalTransposedModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
8 changes: 7 additions & 1 deletion python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,14 @@ def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
continue
elif handler is False:
# Active error.
# try matching attribute type ID to name for a more descriptive error message
try:
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
except ValueError:
attr_type_name = "UNKNOWN"
raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. "
f"ONNX importer does not support generic node attribute type {attr_type_name} "
f"with ID {attr_type}. "
f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}"
)
Expand Down
20 changes: 20 additions & 0 deletions test/Conversion/TorchOnnxToTorch/ops/if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s

// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}, {
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}
Loading