Skip to content

Commit

Permalink
Split inferShapes and added verify to Flatten operator (llvm#1140)
Browse files Browse the repository at this point in the history
* Split inferShapes and added verifier to Flatten operators
Signed-off-by: Adrian Sion <[email protected]>
  • Loading branch information
adriansion authored Feb 14, 2022
1 parent 54941f9 commit d3159e2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2835,19 +2835,37 @@ LogicalResult ONNXSplitV11Op::inferShapes(
// Flatten
//===----------------------------------------------------------------------===//

static LogicalResult verify(ONNXFlattenOp op) {

if (!hasShapeAndRank(op.input())) {
return success();
}
auto inTy = op.input().getType().dyn_cast<ShapedType>();
if (!inTy) {
return success();
}

int64_t axisValue = op.axis();
auto inputShape = inTy.getShape();
int64_t inputRank = inputShape.size();

if (axisValue < -1 * inputRank || axisValue > inputRank) {
return op.emitError("ONNXFlattenOP: axis() value is out of range");
}

return success();
}

LogicalResult ONNXFlattenOp::inferShapes(
std::function<void(mlir::Region &)> doShapeInference) {
auto inTy = input().getType().dyn_cast<RankedTensorType>();
auto inTy = input().getType().dyn_cast_or_null<RankedTensorType>();
if (!inTy) {
return success();
}

int64_t axisValue = axis();
auto inputShape = inTy.getShape();
int64_t inputRank = inputShape.size();
if (axisValue < -1 * inputRank || axisValue > inputRank) {
return emitOpError("ONNXFlattenOP: axis() value is out of range");
}

SmallVector<int64_t, 2> dims;

Expand Down
1 change: 1 addition & 0 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,7 @@ def ONNXFlattenOp:ONNX_Op<"Flatten",
return {20};
}
}];
let verifier = [{ return ::verify(*this); }];
}

def ONNXFloorOp:ONNX_Op<"Floor",
Expand Down
10 changes: 10 additions & 0 deletions test/mlir/onnx/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,13 @@ func @test_concat_verifier_3(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x
%1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 1 : si64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
}

// -----

func @test_flatten_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// expected-error @+1 {{ONNXFlattenOP: axis() value is out of range}}
%1 = "onnx.Flatten"(%arg0) { axis = 5 : si64} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()
}

// -----
1 change: 1 addition & 0 deletions utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@
'Conv',
'DepthToSpace',
'Expand',
'Flatten',
'Hardmax',
'InstanceNormalization',
'Mod',
Expand Down

0 comments on commit d3159e2

Please sign in to comment.