From a033df8c31311b6710570a3b7103dd8c2f9f9a64 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 18 Mar 2024 10:28:39 -0700 Subject: [PATCH] Implement CustomOp Output Type Inference function (#19906) ### Description This change addresses the following issues with the current CustomOp output type inference: - The function does not take optional inputs into account. When an input is absent, inference is silently aborted and no output type is inferred (P1 customer issue). - For multi-kernel custom ops, the output type is inferred from the latest kernel definition in the sequence; no attempt is made to match a kernel to the actual input types. - Inference is aborted when variadic inputs/outputs are detected, because the generated input/output names fail to obtain type constraints. This is not immediately clear from the code, since the custom op schema is not available within the inference function. - No error reporting. ### Motivation and Context Most custom ops lack their own type and shape inference function, since the ability to supply one was only recently introduced. For that reason, it is important to fix the default inference. This change is inspired by a customer issue. This is a follow-up on: - https://github.com/microsoft/onnxruntime/pull/15184 - https://github.com/cbourjau/ort-custom-op/pull/11 - https://github.com/microsoft/onnxruntime-extensions/issues/451 --- .../core/session/onnxruntime_c_api.h | 18 +- .../core/session/onnxruntime_cxx_api.h | 4 + onnxruntime/core/session/custom_ops.cc | 157 ++++++++++++------ .../test/framework/shape_inference_test.cc | 93 ++++++++++- onnxruntime/test/shared_lib/test_inference.cc | 7 +- 5 files changed, 224 insertions(+), 55 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cef50163f68b0..41b034e9c1dcc 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -1837,14 +1837,28 @@ struct OrtApi { /** \brief Used for custom operators, get an input of a kernel * - * \see ::OrtCustomOp + * The function attempts to fetch the input of the kernel. If the input is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] index Input index. See KernelContext_GetInputCount for the valid range. + * \param[out] out Set to a pointer to the OrtValue if the input is present, nullptr otherwise + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out); /** \brief Used for custom operators, get an output of a kernel * - * \see ::OrtCustomOp + * The function attempts to fetch the output of the kernel. If the output is optional + * and not present, the function returns success and out is set to nullptr. + * + * \param[in] context ::OrtKernelContext instance + * \param[in] index Output index. See KernelContext_GetOutputCount for the valid range.
+ * \param[out] out Set to a pointer to the OrtValue if the output is present, nullptr otherwise + * + * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index ae4c4bef90c64..60540514fbfa6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2055,7 +2055,11 @@ struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; + // If the input is optional and not present, the method returns an empty ConstValue + // which can be compared to nullptr. ConstValue GetInput(size_t index) const; + // If the output is optional and not present, the method returns an empty UnownedValue + // which can be compared to nullptr. UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const; void* GetGPUComputeStream() const; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 6e9d68d259a5d..513aafcdadb7d 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -1066,59 +1066,120 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o return Status::OK(); } -void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs, - ONNX_NAMESPACE::InferenceContext& infer_ctx) { - for (const auto& kernel_def : kernel_defs) { +// This function attempts to do its best for older custom ops (most of them) that do not have +// their own type and shape inference function. However, it falls short in some cases, and we leave +// those for the user to handle in their own inference function.
+static void InferOutputTypes(const ONNX_NAMESPACE::OpSchema& schema, gsl::span<const KernelDef* const> kernel_defs, + ONNX_NAMESPACE::InferenceContext& infer_ctx) { + const auto& inputs = schema.inputs(); + const auto node_input_num = infer_ctx.getNumInputs(); + + const KernelDef* def_selected = nullptr; + bool is_variadic_input = false; + bool is_homogeneous_input = false; + int32_t output_propagate{0}; + + for (size_t kernel_index = 0; + kernel_index < kernel_defs.size() && def_selected == nullptr; + ++kernel_index) { + const auto* kernel_def = kernel_defs[kernel_index]; const auto& type_constraints = kernel_def->TypeConstraints(); - auto num_inputs = infer_ctx.getNumInputs(); - bool matched = true; - ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; - // first, make sure there is a constraint for every input - for (size_t i = 0; i < num_inputs && matched; ++i) { - auto input_name = "Input" + std::to_string(i); - auto input_type = infer_ctx.getInputType(i); - if (input_type) { - auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type()); - auto tc_iter = type_constraints.find(input_name); - if (tc_iter != type_constraints.end()) { - if (tc_iter->second.size() > 1) { - undef = elem_type; - } else if (tc_iter->second.size() != 1 || - tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) { - matched = false; + def_selected = kernel_def; + + for (size_t i = 0; i < node_input_num; ++i) { + const auto input_type = infer_ctx.getInputType(i); + + // Guard against variadic parameter index + const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1; + const auto& param = inputs[schema_input_index]; + const auto& input_name = param.GetName(); + if (input_type == nullptr) { + if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional) + continue; + + ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name, + " is absent, but not optional. Op : ", schema.Name()); + } + + is_variadic_input = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic); + is_homogeneous_input = param.GetIsHomogeneous(); + + if (!is_variadic_input || is_homogeneous_input) { + auto hit = type_constraints.find(input_name); + if (hit != type_constraints.end()) { + const auto& types = hit->second; + // For custom ops kernel constraints are never empty + assert(!types.empty()); + if (!std::any_of(types.cbegin(), types.cend(), + [input_type](const DataTypeImpl* type) { + return type->IsCompatible(*input_type); + })) { + def_selected = nullptr; + output_propagate = 0; + break; + } + + // If we have multiple types possible from the constraints, + // record the last type and use it to guess the output type when + // the output may have different types.
This works well for symmetric single input/output ops; + // otherwise we give up and let the user supply their own function + if (types.size() > 1) { + output_propagate = input_type->tensor_type().elem_type(); + } } else { - matched = false; + ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ", + input_name, " Op: ", schema.Name()); } - } else { - matched = false; - } - } // for - // next, ensure that there is a constraint for every output - auto num_outputs = infer_ctx.getNumOutputs(); - for (size_t i = 0; i < num_outputs && matched; i++) { - auto output_name = "Output" + std::to_string(i); - auto tc_iter = type_constraints.find(output_name); - if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) { - matched = false; } } - if (matched) { - for (size_t i = 0; i < num_outputs; i++) { - auto output_name = "Output" + std::to_string(i); - auto output_type = infer_ctx.getOutputType(i); - auto tc_iter = type_constraints.find(output_name); - if (tc_iter->second.size() > 1) { - output_type->mutable_tensor_type()->set_elem_type(undef); - } else { - output_type->mutable_tensor_type()->set_elem_type( - tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type()); - } - } + } + + if (def_selected == nullptr) { + ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name()); + } + + const auto& outputs = schema.outputs(); + const auto node_output_num = infer_ctx.getNumOutputs(); + const auto& selected_type_constraints = def_selected->TypeConstraints(); + + for (size_t i = 0; i < node_output_num; ++i) { + auto output_type = infer_ctx.getOutputType(i); + // Account for variadic outputs + const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1; + const auto& param = outputs[schema_output_index]; + const auto& output_name = param.GetName(); + + const bool is_variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic); + const bool is_homogeneous = param.GetIsHomogeneous(); + + // We give up on variadic non-homogeneous outputs; + // let the user handle them in their own inference function + if (is_variadic_output && !is_homogeneous) { break; } + + auto hit = selected_type_constraints.find(output_name); + if (hit != selected_type_constraints.end()) { + const auto& types = hit->second; + assert(!types.empty()); + + if (types.size() == 1) { + // Use the constraint type + output_type->mutable_tensor_type()->set_elem_type( + types[0]->GetTypeProto()->tensor_type().elem_type()); + } else if (!is_variadic_input || is_homogeneous_input) { + // If the input is not variadic or is homogeneous, and multiple types are possible, guess from the last recorded input type, + // as this works for symmetric single input/output ops; + // otherwise give up and let the user supply their own function + output_type->mutable_tensor_type()->set_elem_type(output_propagate); + } + } else { + ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ", + output_name, " Op: ", schema.Name()); + } } } + #endif common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains, @@ -1178,13 +1239,13 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai } std::vector<ONNX_NAMESPACE::OpSchema> schemas; - for (auto schema_iter : schema_map) { - schemas.push_back(schema_iter.second); - InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]); + for (auto& [name, schema] : schema_map) { + schemas.push_back(schema); auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction(); ONNX_NAMESPACE::InferenceFunction
extended_infer_fn = - [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { - InferOutputTypes(kernel_defs, infer_ctx); + [sch = schema, infer_fn = std::move(infer_fn), + kernel_defs = std::move(kernel_def_map[name])](ONNX_NAMESPACE::InferenceContext& infer_ctx) { + InferOutputTypes(sch, kernel_defs, infer_ctx); if (infer_fn) { infer_fn(infer_ctx); } diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc index bfabcd567803b..f5258760eb20d 100644 --- a/onnxruntime/test/framework/shape_inference_test.cc +++ b/onnxruntime/test/framework/shape_inference_test.cc @@ -5,13 +5,16 @@ #include #include "gtest/gtest.h" +#include "core/common/span_utils.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_cxx_api.h" #include "test/framework/model_builder_utils.h" +#include "test/util/include/asserts.h" #include "test/util/include/test_utils.h" +#include "test/util/include/inference_session_wrapper.h" #include "test/test_environment.h" using namespace ONNX_NAMESPACE; -using namespace std; namespace onnxruntime { namespace test { @@ -22,7 +25,7 @@ class ShapeInferenceTest : public ::testing::Test { protected: onnxruntime::Model model_; int node_count_; - std::unordered_map<string, std::unique_ptr<NodeArg>> name_to_arg_; + std::unordered_map<std::string, std::unique_ptr<NodeArg>> name_to_arg_; public: ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {} @@ -73,5 +76,91 @@ TEST_F(ShapeInferenceTest, BasicTest) { CheckShapeEquality(InputShape(node), OutputShape(node)); } +namespace { +struct MyCustomKernelWithOptionalInput { + MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const { + return nullptr; + } +}; + +struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput, MyCustomKernelWithOptionalInput> { + explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {} + + OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* info, void** kernel) const { + *kernel = new MyCustomKernelWithOptionalInput(info); + return nullptr; + }; + + const char* GetName() const { return "FooBar"; }; + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return 3; }; + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const { + // The second input (index == 1) is optional + if (index == 1) + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL; + + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + size_t GetOutputTypeCount() const { return 1; }; + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; }; + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED; + } + + private: + const char* provider_; +}; + +const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata/foo_bar_2.onnx"); + +} // namespace + +// The CustomOp output type inference function used to quit when it +// encountered an input that is optional and absent. +// It quit silently, without any errors or logging.
We want to make sure +// that inference proceeds for all of the outputs when an optional input is absent +TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) { + MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider}; + + const auto& env = GetEnvironment(); + + Ort::CustomOpDomain op_domain("test"); + op_domain.Add(&custom_op); + + std::initializer_list<OrtCustomOpDomain*> op_domains = {static_cast<OrtCustomOpDomain*>(op_domain)}; + + SessionOptions sess_opts; + sess_opts.inter_op_param.thread_pool_size = 1; + sess_opts.intra_op_param.thread_pool_size = 1; + + InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2}; + ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains))); + + ASSERT_STATUS_OK(session.Load()); + ASSERT_STATUS_OK(session.Initialize()); + + const onnxruntime::Model& model = session.GetModel(); + const auto& graph = model.MainGraph(); + const auto& nodes = graph.Nodes(); + for (const auto& node : nodes) { + if (node.OpType() == "FooBar") { + // check inferred shapes + const auto* node_arg = node.OutputDefs()[0]; + const auto* type_proto = node_arg->TypeAsProto(); + ASSERT_NE(nullptr, type_proto); + ASSERT_EQ(ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType, type_proto->value_case()); + ASSERT_EQ(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, type_proto->tensor_type().elem_type()); + } + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 91453102d406f..52dd2a84e383b 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -208,7 +208,7 @@ static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/mod static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/VariedInputCustomOp.onnx"); static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx"); static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx"); -static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx"); +static constexpr PATH_TYPE OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx"); static constexpr PATH_TYPE VARIADIC_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/custom_op_variadic_io.onnx"); static constexpr PATH_TYPE VARIADIC_UNDEF_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR( "testdata/custom_op_variadic_undef_io.onnx"); @@ -1082,7 +1082,7 @@ TEST(CApiTest, invalid_variadic_input_homogeneity_custom_op) { } } -TEST(CApiTest, optional_input_output_custom_op_handler) { +TEST(CApiTest, optional_input_custom_op_handler) { MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider}; // `MyCustomOpFooBar` defines a custom op with atmost 3 inputs and the second input is optional.
@@ -1147,7 +1147,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) { { std::vector<const char*> input_names = {"X1", "X2"}; ort_inputs.erase(ort_inputs.begin() + 2); // remove the last input in the container - Ort::Session session(*ort_env, OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2, session_options); + Ort::Session session(*ort_env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2, session_options); auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(), &output_name, 1); ASSERT_EQ(ort_outputs.size(), 1u); @@ -1166,6 +1166,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) { } } } + TEST(CApiTest, custom_op_with_attributes_handler) { MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};
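
For reference, a minimal sketch (not part of the patch) of how a custom op kernel can consume an optional input under the behavior documented above: `KernelContext::GetInput` returns an empty value for an absent optional input, and that value can be compared to nullptr. The kernel name and the float-add computation here are hypothetical, assuming a required input 0 and an optional input 1.

```cpp
// Hypothetical kernel illustrating handling of an absent optional input.
#include "onnxruntime_cxx_api.h"

struct HypotheticalOptionalInputKernel {
  OrtStatusPtr ComputeV2(OrtKernelContext* context) const {
    Ort::KernelContext ctx(context);

    Ort::ConstValue x = ctx.GetInput(0);     // required input
    Ort::ConstValue bias = ctx.GetInput(1);  // optional input; may be absent

    // Output takes the shape of the required input.
    const auto shape = x.GetTensorTypeAndShapeInfo().GetShape();
    Ort::UnownedValue y = ctx.GetOutput(0, shape);

    const float* x_data = x.GetTensorData<float>();
    float* y_data = y.GetTensorMutableData<float>();
    const size_t n = x.GetTensorTypeAndShapeInfo().GetElementCount();

    // An absent optional input compares equal to nullptr (per the patch's documented behavior).
    const float* bias_data = (bias != nullptr) ? bias.GetTensorData<float>() : nullptr;

    for (size_t i = 0; i < n; ++i) {
      y_data[i] = x_data[i] + (bias_data ? bias_data[i] : 0.0f);
    }
    return nullptr;  // success
  }
};
```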