Implement CustomOp Output Type Inference function (#19906)
### Description
This change addresses the following issues with the current CustomOp
output type inference:
- The function does not take optional inputs into account. When an input is
absent, inference is silently aborted and no output type is inferred
(P1 customer issue). A minimal sketch of such an op appears below.
- For multi-kernel custom ops, the output type is inferred from the last
kernel definition in the sequence; no attempt is made to match a kernel
against the actual input types.
- Inference is aborted when variadic inputs/outputs are detected, because the
generated input/output names fail to match any type constraint. This is not
immediately clear from the code, since the custom op schema is not
available within the inference function.
- No error reporting.
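
For reference, here is a minimal sketch of a custom op that triggers the first issue: three float inputs, with the second one optional. It mirrors the `MyCustomOpWithOptionalInput` test op added in this change; the struct names are illustrative only.

```cpp
#include "onnxruntime_cxx_api.h"

// Kernel with a no-op compute; what matters for inference is the op definition below.
struct OptionalInputKernel {
  explicit OptionalInputKernel(const OrtKernelInfo* /*info*/) {}
  OrtStatusPtr ComputeV2(OrtKernelContext* /*context*/) const { return nullptr; }
};

// The trailing `true` enables the status-returning (V2) entry points, as in the test added here.
struct OptionalInputOp : Ort::CustomOpBase<OptionalInputOp, OptionalInputKernel, true> {
  OrtStatusPtr CreateKernelV2(const OrtApi& /*api*/, const OrtKernelInfo* info, void** kernel) const {
    *kernel = new OptionalInputKernel(info);
    return nullptr;
  }
  const char* GetName() const { return "FooBar"; }

  size_t GetInputTypeCount() const { return 3; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
    // The second input (index == 1) is optional; when a model does not feed it,
    // the previous inference code silently skipped output type inference for the node.
    return index == 1 ? OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL
                      : OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
  }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
```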

### Motivation and Context

Most custom ops lack their own type and shape inference function, since the
ability to supply one was introduced only recently. For that reason, it is
important to fix the default inference. This change is inspired by a customer issue.

This is a follow-up to:
- #15184
- cbourjau/ort-custom-op#11
- microsoft/onnxruntime-extensions#451
yuslepukhin authored Mar 18, 2024
1 parent 4d31076 commit a033df8
Showing 5 changed files with 224 additions and 55 deletions.
18 changes: 16 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1837,14 +1837,28 @@ struct OrtApi {

/** \brief Used for custom operators, get an input of a kernel
*
* \see ::OrtCustomOp
* The function attempts to fetch the input of the kernel. If the input is optional
* and not present, the function returns success and *out is set to nullptr.
*
* \param[in] context ::OrtKernelContext instance
* \param[in] index Input index. See KernelContext_GetInputCount for the valid range.
* \param[out] out Set to a pointer to the OrtValue if the input is present, nullptr otherwise
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
_Out_ const OrtValue** out);

/** \brief Used for custom operators, get an output of a kernel
*
* \see ::OrtCustomOp
* The function attempts to fetch the output of the kernel. If the output is optional
* and not present, the function returns success and *out is set to nullptr.
*
* \param[in] context ::OrtKernelContext instance
* \param[in] index Output index. See KernelContext_GetOutputCount for the valid range.
* \param[out] out Set to a pointer to the OrtValue if the output is present, nullptr otherwise
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
_In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out);
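
A short usage sketch for the behavior documented above: inside a hypothetical custom kernel compute callback, an absent optional input now comes back as a success status with `out` set to nullptr. `ComputeExample` and the input/output indices are assumptions for illustration; `ort` is the `OrtApi*` obtained from `OrtGetApiBase()->GetApi(ORT_API_VERSION)`.

```cpp
OrtStatusPtr ComputeExample(const OrtApi* ort, OrtKernelContext* context) {
  const OrtValue* optional_input = nullptr;
  // Index 1 is the optional input in this hypothetical op.
  OrtStatusPtr status = ort->KernelContext_GetInput(context, 1, &optional_input);
  if (status != nullptr) return status;  // a real error, not an absent optional input

  if (optional_input == nullptr) {
    // The optional input was not provided; fall back to default behavior.
  }

  // Allocate the first output with a fixed shape of {1}, purely for illustration.
  const int64_t dims[] = {1};
  OrtValue* output = nullptr;
  status = ort->KernelContext_GetOutput(context, 0, dims, 1, &output);
  return status;
}
```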
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2055,7 +2055,11 @@ struct KernelContext {
explicit KernelContext(OrtKernelContext* context);
size_t GetInputCount() const;
size_t GetOutputCount() const;
// If the input is optional and not present, the method returns an empty ConstValue
// which can be compared to nullptr.
ConstValue GetInput(size_t index) const;
// If the output is optional and not present, the method returns an empty UnownedValue
// which can be compared to nullptr.
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
void* GetGPUComputeStream() const;
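
The C++ wrapper mirrors this behavior. Below is a sketch, under the assumption of a hypothetical op whose second input is optional, of how a `ComputeV2` implementation can detect the absent input via the empty `ConstValue` described in the comments above.

```cpp
#include <vector>
#include "onnxruntime_cxx_api.h"

OrtStatusPtr ComputeV2Example(OrtKernelContext* raw_context) {
  Ort::KernelContext context(raw_context);

  Ort::ConstValue required = context.GetInput(0);
  Ort::ConstValue optional = context.GetInput(1);  // the optional input in this hypothetical op

  if (optional == nullptr) {
    // The input was not fed at runtime; proceed without it.
  }

  // Produce an output with the same shape as the required input (illustrative only).
  std::vector<int64_t> shape = required.GetTensorTypeAndShapeInfo().GetShape();
  Ort::UnownedValue out = context.GetOutput(0, shape);
  (void)out;
  return nullptr;
}
```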
157 changes: 109 additions & 48 deletions onnxruntime/core/session/custom_ops.cc
@@ -1066,59 +1066,120 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
return Status::OK();
}

void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
for (const auto& kernel_def : kernel_defs) {
// This function attempts to do its best for older custom ops (most of them) that do not have
// their own type and shape inference function. However, it falls short in some cases, and we leave
// those for the user to handle in their own inference function.
static void InferOutputTypes(const ONNX_NAMESPACE::OpSchema& schema, gsl::span<const KernelDef* const> kernel_defs,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
const auto& inputs = schema.inputs();
const auto node_input_num = infer_ctx.getNumInputs();

const KernelDef* def_selected = nullptr;
bool is_variadic_input = false;
bool is_homogeneous_input = false;
int32_t output_propagate{0};

for (size_t kernel_index = 0;
kernel_index < kernel_defs.size() && def_selected == nullptr;
++kernel_index) {
const auto* kernel_def = kernel_defs[kernel_index];
const auto& type_constraints = kernel_def->TypeConstraints();
auto num_inputs = infer_ctx.getNumInputs();
bool matched = true;
ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
// first, make sure there is a constraint for every input
for (size_t i = 0; i < num_inputs && matched; ++i) {
auto input_name = "Input" + std::to_string(i);
auto input_type = infer_ctx.getInputType(i);
if (input_type) {
auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
auto tc_iter = type_constraints.find(input_name);
if (tc_iter != type_constraints.end()) {
if (tc_iter->second.size() > 1) {
undef = elem_type;
} else if (tc_iter->second.size() != 1 ||
tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
matched = false;
def_selected = kernel_def;

for (size_t i = 0; i < node_input_num; ++i) {
const auto input_type = infer_ctx.getInputType(i);

// Guard against variadic parameter index
const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
const auto& param = inputs[schema_input_index];
const auto& input_name = param.GetName();
if (input_type == nullptr) {
if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
continue;

ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
" is absent, but not optional. Op : ", schema.Name());
}

is_variadic_input = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
is_homogeneous_input = param.GetIsHomogeneous();

if (!is_variadic_input || is_homogeneous_input) {
auto hit = type_constraints.find(input_name);
if (hit != type_constraints.end()) {
const auto& types = hit->second;
// For custom ops kernel constraints are never empty
assert(!types.empty());
if (!std::any_of(types.cbegin(), types.cend(),
[input_type](const DataTypeImpl* type) {
return type->IsCompatible(*input_type);
})) {
def_selected = nullptr;
output_propagate = 0;
break;
}

// If the constraints allow multiple types, record the last input's element type
// and use it to guess the output type in case the output can also take several types.
// This works well for symmetric single input/output ops; otherwise we give up and
// let the user supply their own inference function.
if (types.size() > 1) {
output_propagate = input_type->tensor_type().elem_type();
}
} else {
matched = false;
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
input_name, " Op: ", schema.Name());
}
} else {
matched = false;
}
} // for
// next, ensure that there is a constraint for every output
auto num_outputs = infer_ctx.getNumOutputs();
for (size_t i = 0; i < num_outputs && matched; i++) {
auto output_name = "Output" + std::to_string(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
matched = false;
}
}
if (matched) {
for (size_t i = 0; i < num_outputs; i++) {
auto output_name = "Output" + std::to_string(i);
auto output_type = infer_ctx.getOutputType(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter->second.size() > 1) {
output_type->mutable_tensor_type()->set_elem_type(undef);
} else {
output_type->mutable_tensor_type()->set_elem_type(
tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
}
}
}

if (def_selected == nullptr) {
ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
}

const auto& outputs = schema.outputs();
const auto node_output_num = infer_ctx.getNumOutputs();
const auto& selected_type_constraints = def_selected->TypeConstraints();

for (size_t i = 0; i < node_output_num; ++i) {
auto output_type = infer_ctx.getOutputType(i);
// Account for variadic outputs
const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
const auto& param = outputs[schema_output_index];
const auto& output_name = param.GetName();

const bool is_variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
const bool is_homogeneous = param.GetIsHomogeneous();

// We give up on variadic non-homogeneous outputs
// Let the user handle it in their inference function
if (is_variadic_output && !is_homogeneous) {
break;
}

auto hit = selected_type_constraints.find(output_name);
if (hit != selected_type_constraints.end()) {
const auto& types = hit->second;
assert(!types.empty());

if (types.size() == 1) {
// Use the constraint type
output_type->mutable_tensor_type()->set_elem_type(
types[0]->GetTypeProto()->tensor_type().elem_type());
} else if (!is_variadic_input || is_homogeneous_input) {
// If the input is not variadic, or is homogeneous, and multiple types are possible,
// guess from the last recorded input type; this works for symmetric single input/output ops.
// Otherwise give up and let the user supply their own inference function.
output_type->mutable_tensor_type()->set_elem_type(output_propagate);
}
} else {
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ",
output_name, " Op: ", schema.Name());
}
}
}

#endif

common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains,
@@ -1178,13 +1239,13 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
}

std::vector<ONNX_NAMESPACE::OpSchema> schemas;
for (auto schema_iter : schema_map) {
schemas.push_back(schema_iter.second);
InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]);
for (auto& [name, schema] : schema_map) {
schemas.push_back(schema);
auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction();
ONNX_NAMESPACE::InferenceFunction extended_infer_fn =
[infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(kernel_defs, infer_ctx);
[sch = schema, infer_fn = std::move(infer_fn),
kernel_defs = std::move(kernel_def_map[name])](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(sch, kernel_defs, infer_ctx);
if (infer_fn) {
infer_fn(infer_ctx);
}
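
To make the selection logic above concrete, here is a small self-contained model of the heuristic (not ORT code; element types are plain strings purely for illustration): each kernel def maps the generated formal parameter names ("Input0", "Output0", ...) to allowed types, the first kernel whose constraints accept the node's input types is selected, and its output constraint drives the inferred output type.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// One entry per kernel def: formal parameter name -> allowed element types.
using TypeConstraints = std::map<std::string, std::vector<std::string>>;

static const TypeConstraints* SelectKernel(const std::vector<TypeConstraints>& kernels,
                                           const std::string& node_input_type) {
  for (const auto& kernel : kernels) {
    for (const auto& allowed : kernel.at("Input0")) {
      if (allowed == node_input_type) {
        return &kernel;  // first compatible kernel wins
      }
    }
  }
  return nullptr;  // no match; the real code now throws a descriptive error
}

int main() {
  const std::vector<TypeConstraints> kernels = {
      {{"Input0", {"float"}}, {"Output0", {"float"}}},
      {{"Input0", {"double"}}, {"Output0", {"double"}}},
  };
  // A node feeding a double tensor now selects the second kernel, so the output
  // element type is inferred as double rather than whatever the last kernel allowed.
  const TypeConstraints* selected = SelectKernel(kernels, "double");
  assert(selected != nullptr && selected->at("Output0").front() == "double");
  return 0;
}
```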
93 changes: 91 additions & 2 deletions onnxruntime/test/framework/shape_inference_test.cc
@@ -5,13 +5,16 @@
#include <unordered_map>

#include "gtest/gtest.h"
#include "core/common/span_utils.h"
#include "core/graph/model.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "test/framework/model_builder_utils.h"
#include "test/util/include/asserts.h"
#include "test/util/include/test_utils.h"
#include "test/util/include/inference_session_wrapper.h"
#include "test/test_environment.h"

using namespace ONNX_NAMESPACE;
using namespace std;

namespace onnxruntime {
namespace test {
@@ -22,7 +25,7 @@ class ShapeInferenceTest : public ::testing::Test {
protected:
onnxruntime::Model model_;
int node_count_;
std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
std::unordered_map<std::string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;

public:
ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
@@ -73,5 +76,91 @@ TEST_F(ShapeInferenceTest, BasicTest) {
CheckShapeEquality(InputShape(node), OutputShape(node));
}

namespace {
struct MyCustomKernelWithOptionalInput {
MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
}

OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
return nullptr;
}
};

struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput,
MyCustomKernelWithOptionalInput,
true> {
explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {}

OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* info, void** kernel) const {
*kernel = new MyCustomKernelWithOptionalInput(info);
return nullptr;
};

const char* GetName() const { return "FooBar"; };
const char* GetExecutionProviderType() const { return provider_; };

size_t GetInputTypeCount() const { return 3; };
ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
// The second input (index == 1) is optional
if (index == 1)
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;

return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

size_t GetOutputTypeCount() const { return 1; };
ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
}

private:
const char* provider_;
};

const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata/foo_bar_2.onnx");

} // namespace

// The CustomOp output type inference function used to quit if it
// encountered an optional input that was absent.
// It quit without any error or logging. We want to make sure
// that inference proceeds for all of the outputs when optional inputs are absent.
TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};

const auto& env = GetEnvironment();

Ort::CustomOpDomain op_domain("test");
op_domain.Add(&custom_op);

std::initializer_list<OrtCustomOpDomain*> op_domains = {static_cast<OrtCustomOpDomain*>(op_domain)};

SessionOptions sess_opts;
sess_opts.inter_op_param.thread_pool_size = 1;
sess_opts.intra_op_param.thread_pool_size = 1;

InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));

ASSERT_STATUS_OK(session.Load());
ASSERT_STATUS_OK(session.Initialize());

const onnxruntime::Model& model = session.GetModel();
const auto& graph = model.MainGraph();
const auto& nodes = graph.Nodes();
for (const auto& node : nodes) {
if (node.OpType() == "FooBar") {
// check inferred shapes
const auto* node_arg = node.OutputDefs()[0];
const auto* type_proto = node_arg->TypeAsProto();
ASSERT_NE(nullptr, type_proto);
ASSERT_EQ(ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType, type_proto->value_case());
ASSERT_EQ(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, type_proto->tensor_type().elem_type());
}
}
}

} // namespace test
} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/test/shared_lib/test_inference.cc
@@ -208,7 +208,7 @@ static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/mod
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/VariedInputCustomOp.onnx");
static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
static constexpr PATH_TYPE VARIADIC_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/custom_op_variadic_io.onnx");
static constexpr PATH_TYPE VARIADIC_UNDEF_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR(
"testdata/custom_op_variadic_undef_io.onnx");
@@ -1082,7 +1082,7 @@ TEST(CApiTest, invalid_variadic_input_homogeneity_custom_op) {
}
}

TEST(CApiTest, optional_input_output_custom_op_handler) {
TEST(CApiTest, optional_input_custom_op_handler) {
MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};

// `MyCustomOpFooBar` defines a custom op with at most 3 inputs, and the second input is optional.
@@ -1147,7 +1147,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
{
std::vector<const char*> input_names = {"X1", "X2"};
ort_inputs.erase(ort_inputs.begin() + 2); // remove the last input in the container
Ort::Session session(*ort_env, OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2, session_options);
Ort::Session session(*ort_env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2, session_options);
auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, 1);
ASSERT_EQ(ort_outputs.size(), 1u);
@@ -1166,6 +1166,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
}
}
}

TEST(CApiTest, custom_op_with_attributes_handler) {
MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};
