Skip to content

Commit

Permalink
Support onnx.If (#2825)
Browse files Browse the repository at this point in the history
This is probably a decent PR for learning about blocks and regions.

If you're here to learn about that, consider also looking at
lib/Conversion/TorchToSCF/TorchToSCF.cpp

While this doesn't include an e2e test, it is tested downstream in
https://github.com/nod-ai/SHARK-TestSuite/blob/main/e2eshark/onnx/operators/If/model.py

---------

Co-authored-by: Xida Ren <[email protected]>
  • Loading branch information
renxida and Xida Ren authored Apr 30, 2024
1 parent 315dc6c commit 33eef15
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 9 deletions.
25 changes: 25 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,31 @@ struct OpBinder {
return success();
}

ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) {
for (auto result : op->getResults()) {
auto t = toValidTensorType(result.getType());
if (!t)
return failure();
typeList.push_back(t);
}
return success();
}

// The importer imports Onnx.GraphProto attributes as regions attached to the
// op.
ParseResult getRegionAtIndex(mlir::Region *&region, int64_t idx) {
if (idx >= op->getNumRegions())
return failure();

region = &op->getRegion(idx);

if (region == nullptr) {
return failure();
}

return success();
}

ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) {
if (idx >= op->getNumResults())
Expand Down
54 changes: 54 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,60 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
alignCorners);
return success();
});
patterns.onOp(
"If", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value conditionTensor;
if (binder.tensorOperand(conditionTensor)) {
return rewriter.notifyMatchFailure(binder.op,
"condition bind failure");
}

auto conditionType =
conditionTensor.getType().cast<Torch::ValueTensorType>();
if (!conditionType || conditionType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(
binder.op, "condition must have one single element per "
"https://onnx.ai/onnx/operators/onnx__If.html");
auto conditionInt = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
conditionTensor);
auto conditionBool = rewriter.create<Torch::AtenBoolIntOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), conditionInt);

llvm::SmallVector<mlir::Type> resultTypes;
if (binder.tensorResultTypes(resultTypes)) {
return rewriter.notifyMatchFailure(binder.op,
"result type bind failure");
}

Region *thenRegion, *elseRegion;
if (binder.getRegionAtIndex(elseRegion, 0) ||
binder.getRegionAtIndex(thenRegion, 1)) {
return rewriter.notifyMatchFailure(binder.op, "region bind failure");
}

auto primIfOp = rewriter.create<Torch::PrimIfOp>(
binder.getLoc(), TypeRange(resultTypes), conditionBool);

auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) {
rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin());
};
inlineIfCase(*thenRegion, primIfOp.getThenRegion());
inlineIfCase(*elseRegion, primIfOp.getElseRegion());

auto replaceTerminator = [&](Region &region) {
PatternRewriter::InsertionGuard guard(rewriter);
Operation *terminator = region.front().getTerminator();
rewriter.setInsertionPoint(terminator);
rewriter.replaceOpWithNewOp<Torch::PrimIfYieldOp>(
terminator, terminator->getOperands());
};
replaceTerminator(primIfOp.getThenRegion());
replaceTerminator(primIfOp.getElseRegion());

rewriter.replaceOp(binder.op, primIfOp.getResults());
return success();
});
patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
12 changes: 4 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2562,16 +2562,12 @@
"_ConvolutionDeprecated2DCudnnModule_basic",
"_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# Failure - onnx_lowering: onnx.If
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
"DiagonalModule_with_dims",
"DiagonalModule_with_dims_and_offset",
"DiagonalModule_with_negative_dims",
"DiagonalModule_with_offset",
# these diagonal modules are currently failing due to dynamic shape.
# We are currently testing aten.diagonal using DiagonalWithStaticShapeModule instead.
# when the issue is fixed, please remove DiagonalWithStaticShapeModule as well as the xfails here.
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
# Failure - onnx_lowering: onnx.MaxPool
Expand Down
31 changes: 31 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,37 @@ def DiagonalModule_nonsquare(module, tu: TestUtils):
# ==============================================================================


class DiagonalWithStaticShapeModule(torch.nn.Module):
"""
Diagonal with static shape. The other diagonal modules are failing in onnx
because DecomoposeAtenEyeMOp requires constants n, m, which are only constant
when the shape is static.
Please remove this module and associated test once the issue is fixed.
"""

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([5, 9], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.diagonal(a)


@register_test_case(module_factory=lambda: DiagonalWithStaticShapeModule())
def DiagonalWithStaticShapeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 9))


# ==============================================================================


class DiagonalTransposedModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
8 changes: 7 additions & 1 deletion python/torch_mlir/extras/onnx_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,14 @@ def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
continue
elif handler is False:
# Active error.
# try matching attribute type ID to name for a more descriptive error message
try:
attr_type_name = onnx.AttributeProto.AttributeType.Name(attr_type)
except ValueError:
attr_type_name = "UNKNOWN"
raise OnnxImportError(
f"ONNX importer does not support generic node attribute type {attr_type}. "
f"ONNX importer does not support generic node attribute type {attr_type_name} "
f"with ID {attr_type}. "
f"This likely means that this is a special node which requires specific "
f"handling in the importer: {onnx_attr}"
)
Expand Down
20 changes: 20 additions & 0 deletions test/Conversion/TorchOnnxToTorch/ops/if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s

// CHECK-LABEL: func.func @test_ifop_basic
// CHECK: %[[IF:.*]] = torch.prim.If %{{.*}} -> (!torch.vtensor<[1],f32>)
// CHECK-DAG: %[[SUB:.*]] = torch.aten.sub.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[SUB]] : !torch.vtensor<[1],f32>
// CHECK-DAG: } else {
// CHECK-DAG: %[[ADD:.*]] = torch.aten.add.Tensor %arg1, %arg2, %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
// CHECK-DAG: torch.prim.If.yield %[[ADD]] : !torch.vtensor<[1],f32>
func.func @test_ifop_basic(%arg0: !torch.vtensor<[1],i1>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "conditional_example", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
%0 = torch.operator "onnx.If"(%arg0) : (!torch.vtensor<[1],i1>) -> !torch.vtensor<[1],f32> {
%1 = torch.operator "onnx.Add"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}, {
%1 = torch.operator "onnx.Sub"(%arg1, %arg2) : (!torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>
torch.operator_terminator %1 : !torch.vtensor<[1],f32>
}
return %0 : !torch.vtensor<[1],f32>
}

0 comments on commit 33eef15

Please sign in to comment.