Commit ec91b39
Shape Inference now succeeds for unimplemented ops (#2200)
Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger authored May 2, 2023
1 parent d2f4797 commit ec91b39
Showing 6 changed files with 53 additions and 17 deletions.
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXDimAnalysis.cpp
@@ -96,7 +96,8 @@ bool exploreSameInputDims(const onnx_mlir::DimAnalysis::DimT &dim,
  // Get its shape interface.
  onnx_mlir::ONNXOpShapeHelper *shapeHelper =
      shape_op.getShapeHelper(op, {}, nullptr, nullptr);
- if (!shapeHelper)
+ // If no shape helper, or unimplemented, just abort.
+ if (!shapeHelper || !shapeHelper->isImplemented())
    return false;

// Compute shape.
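The bail-out pattern in this hunk can be reduced to a dependency-free sketch. Note that `ShapeHelper`, `UnimplementedShapeHelper`, and `exploreDims` below are hypothetical, simplified stand-ins for the MLIR types, not the actual onnx-mlir API:

```cpp
// Hypothetical stand-ins for ONNXOpShapeHelper and its "unimplemented"
// placeholder subclass.
struct ShapeHelper {
  virtual ~ShapeHelper() = default;
  virtual bool isImplemented() { return true; }
};
struct UnimplementedShapeHelper : ShapeHelper {
  bool isImplemented() override { return false; }
};

// Mirrors the guard in exploreSameInputDims: abort the analysis when there
// is no helper at all, or when the helper is only a placeholder.
bool exploreDims(ShapeHelper *shapeHelper) {
  if (!shapeHelper || !shapeHelper->isImplemented())
    return false;
  // ... compute and compare output dims here ...
  return true;
}
```

The key point of the commit is visible even in this reduction: the analysis degrades gracefully (returns `false`) instead of erroring out when an op has no real shape helper.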
13 changes: 2 additions & 11 deletions src/Dialect/ONNX/ONNXOps.cpp
@@ -24,21 +24,12 @@
// Unsupported Operations
//===---------------------------------------------------------------------===//

-// Operations for which shape inference has not been implemented yet
-// If you add the implementation for one op, move it out of this section
-// Also please add test case in test/mlir/onnx/onnx_shape_inference.mlir
-// Followed by the implementation of lowering to Krnl and
-// Enable the corresponding node test in check-onnx-backend
-
+// Operations for which shape inference has not been implemented.
#define UNSUPPORTED_OPS(OP_TYPE)                                               \
  /* shape inference interface method */                                       \
  mlir::LogicalResult mlir::OP_TYPE::inferShapes(                              \
      std::function<void(mlir::Region &)> doShapeInference) {                  \
-    return emitOpError(                                                       \
-        "op is not supported at this time. Please open an issue on "          \
-        "https://github.com/onnx/onnx-mlir and/or consider contributing "     \
-        "code. "                                                              \
-        "Error encountered in shape inference.");                             \
+    return mlir::success();                                                   \
  }

#include "src/Dialect/ONNX/ONNXUnsupportedOps.hpp"
3 changes: 2 additions & 1 deletion src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -75,7 +75,8 @@ struct ONNXUnimplementedOpShapeHelper : public ONNXOpShapeHelper {
      : ONNXOpShapeHelper(op, operands, ieBuilder, scope) {}
  virtual ~ONNXUnimplementedOpShapeHelper() {}

- mlir::LogicalResult computeShape() final { return mlir::failure(); }
+ bool isImplemented() override { return false; }
+ mlir::LogicalResult computeShape() final { return mlir::success(); }
};

// Classes for unsupported ops, including shape inference and shape helpers.
22 changes: 20 additions & 2 deletions src/Interface/ShapeHelperOpInterface.hpp
@@ -85,9 +85,18 @@ struct ONNXOpShapeHelper {
      IndexExprScope *scope); /* Install local scope if null. */
  virtual ~ONNXOpShapeHelper();

+ // Return true if implemented.
+ virtual bool isImplemented() { return true; }
+
  // Every leaf class is expected to create a computeShape with the following
  // signature. This method is responsible to compute at a minimum the output
  // dims.
+ // Unimplemented operations return success, as these operations may be
+ // transformed later in a sequence of operations with implemented shape
+ // inference. To ensure an implementation, check the `isImplemented` function.
+ // This is used, for example, in dynamic analysis, where unimplemented shape
+ // inferences are simply ignored (and conservatively assume no knowledge about
+ // that operation's transfer function).
  virtual mlir::LogicalResult computeShape() = 0;

// Compute shape and assert on failure.
@@ -105,8 +114,17 @@
      mlir::ArrayRef<mlir::Attribute> encodingList = {});

  // Get output dims for the N-th output dimension as Index Expressions.
- // Scalar may have a DimsExpr that is empty.
- DimsExpr &getOutputDims(int n = 0) { return privateOutputsDims[n]; }
+ // Scalar may have a DimsExpr that is empty. Requires an implementation.
+ DimsExpr &getOutputDims(int n = 0) {
+   if (!isImplemented()) {
+     llvm::errs() << "Implementation of shape helper for op " << op->getName()
+                  << " is not currently available; please open an issue on "
+                  << "\"https://github.com/onnx/onnx-mlir/\" and/or consider "
+                  << "contributing code if this op is required.\n";
+     llvm_unreachable("missing implementation for shape inference");
+   }
+   return privateOutputsDims[n];
+ }
// Set output dims, merging the dims associated with the current type with
// inferred dims provided here, as appropriate.
void setOutputDims(
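Taken together, the new contract is: `computeShape` on a placeholder helper succeeds without computing anything, and only `getOutputDims` enforces a real implementation. A minimal, MLIR-free sketch of that contract (all class names here are simplified stand-ins for the onnx-mlir classes, not the actual API):

```cpp
#include <cstdint>
#include <cassert>
#include <vector>

struct ShapeHelper {
  virtual ~ShapeHelper() = default;
  virtual bool isImplemented() { return true; }
  virtual bool computeShape() = 0; // true == success
  // Reading the result requires a real implementation, mirroring the
  // llvm_unreachable guard added to getOutputDims in this commit.
  std::vector<int64_t> &getOutputDims() {
    assert(isImplemented() && "missing implementation for shape inference");
    return dims;
  }
protected:
  std::vector<int64_t> dims;
};

// Placeholder: reports itself unimplemented, yet computeShape "succeeds"
// so that a pass over many ops does not abort on the first unknown one.
struct UnimplementedShapeHelper : ShapeHelper {
  bool isImplemented() override { return false; }
  bool computeShape() override { return true; }
};

// A real helper: computes dims and allows them to be read back.
struct ImplementedShapeHelper : ShapeHelper {
  bool computeShape() override { dims = {1, 2, 3}; return true; }
};
```

The design choice is to move the failure from shape computation time to result-read time: passes that merely walk ops keep going, while code that actually needs the dims hits the guard.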
7 changes: 5 additions & 2 deletions src/Interface/ShapeHelperOpInterface.td
@@ -30,8 +30,11 @@ def ShapeHelperOpInterface : OpInterface<"ShapeHelperOpInterface"> {

    For operations that do not support shape helpers at this stage, a
    `ONNXUnimplementedOpShapeHelper` object is returned. This object does not
-   compute shapes, and simply return failure when `computeShape` is called
-   on it.
+   compute shapes, and simply returns success when `computeShape` is called
+   on it. Users may verify whether an operation has an actual implementation
+   by calling `isImplemented()` on the shape helper object. An implementation
+   is required when attempting to read the outputs of a shape helper object
+   via the `getOutputDims` method.

    The new object is allocated on the heap and it is the responsibility
    of the object user to free the memory after last use.
22 changes: 22 additions & 0 deletions test/mlir/onnx/onnx_shape_inference.mlir
@@ -3547,3 +3547,25 @@ module {
// CHECK: }
}
}
+
+// -----
+
+// Check that the ClipV6 operation goes through shape inference smoothly.
+// ClipV6 has no shape inference, as it is supposed to be first updated to the latest ClipOp.
+// Using the latest shape inference, the default is to let unimplemented ops go through shape
+// inference without asserts/failures. Asserts only occur when the results of the shape
+// inference are used.
+// The output should be the same as the input, as no shape inference is expected to be performed.
+
+func.func @test_clipv6(%arg0: tensor<*xf32>) {
+  %0 = "onnx.ClipV6"(%arg0) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  return
+
+  // mlir2FileCheck.py
+  // CHECK-LABEL: func.func @test_clipv6
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) {
+  // CHECK: [[VAR_0_:%.+]] = "onnx.ClipV6"([[PARAM_0_]]) {max = 6.000000e+00 : f32, min = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return
+  // CHECK: }
+}
