diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 144ee1205ee1a..ffafbe1d4e5e8 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -1837,14 +1837,28 @@ struct OrtApi {
   /** \brief Used for custom operators, get an input of a kernel
    *
-   * \see ::OrtCustomOp
+   * The function attempts to fetch the input of the kernel. If the input is optional
+   * and not present, the function returns success and out is set to nullptr.
+   *
+   * \param[in] context ::OrtKernelContext instance
+   * \param[in] index Input index. See KernelContext_GetInputCount for bounds checking.
+   * \param[out] out Receives a pointer to the OrtValue if the input is present, nullptr otherwise.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
    */
   ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
                   _Out_ const OrtValue** out);
 
   /** \brief Used for custom operators, get an output of a kernel
    *
-   * \see ::OrtCustomOp
+   * The function attempts to fetch the output of the kernel. If the output is optional
+   * and not present, the function returns success and out is set to nullptr.
+   *
+   * \param[in] context ::OrtKernelContext instance
+   * \param[in] index Output index. See KernelContext_GetOutputCount for bounds checking.
+   * \param[out] out Receives a pointer to the OrtValue if the output is present, nullptr otherwise.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
    */
   ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index,
                   _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out);
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index ae4c4bef90c64..60540514fbfa6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -2055,7 +2055,11 @@ struct KernelContext {
   explicit KernelContext(OrtKernelContext* context);
   size_t GetInputCount() const;
   size_t GetOutputCount() const;
+  // If the input is optional and not present, the method returns an empty ConstValue
+  // which can be compared to nullptr.
   ConstValue GetInput(size_t index) const;
+  // If the output is optional and not present, the method returns an empty UnownedValue
+  // which can be compared to nullptr.
   UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
   UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
   void* GetGPUComputeStream() const;
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 7a233c57cfdf3..3f19b09dd30be 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -1040,59 +1040,120 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
   return Status::OK();
 }
 
-void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
-                      ONNX_NAMESPACE::InferenceContext& infer_ctx) {
-  for (const auto& kernel_def : kernel_defs) {
+// This function attempts to do its best for older custom ops (most of them) that do not have
+// their own type and shape inference function. However, it falls short in some cases, and we leave
+// those for the user to handle in their own inference function.
+static void InferOutputTypes(const ONNX_NAMESPACE::OpSchema& schema, gsl::span<const KernelDef* const> kernel_defs,
+                             ONNX_NAMESPACE::InferenceContext& infer_ctx) {
+  const auto& inputs = schema.inputs();
+  const auto node_input_num = infer_ctx.getNumInputs();
+
+  const KernelDef* def_selected = nullptr;
+  bool is_variadic_input = false;
+  bool is_homogeneous_input = false;
+  int32_t output_propagate{0};
+
+  for (size_t kernel_index = 0;
+       kernel_index < kernel_defs.size() && def_selected == nullptr;
+       ++kernel_index) {
+    const auto* kernel_def = kernel_defs[kernel_index];
     const auto& type_constraints = kernel_def->TypeConstraints();
-    auto num_inputs = infer_ctx.getNumInputs();
-    bool matched = true;
-    ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-    // first, make sure there is a constraint for every input
-    for (size_t i = 0; i < num_inputs && matched; ++i) {
-      auto input_name = "Input" + std::to_string(i);
-      auto input_type = infer_ctx.getInputType(i);
-      if (input_type) {
-        auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
-        auto tc_iter = type_constraints.find(input_name);
-        if (tc_iter != type_constraints.end()) {
-          if (tc_iter->second.size() > 1) {
-            undef = elem_type;
-          } else if (tc_iter->second.size() != 1 ||
-                     tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
-            matched = false;
+    def_selected = kernel_def;
+
+    for (size_t i = 0; i < node_input_num; ++i) {
+      const auto input_type = infer_ctx.getInputType(i);
+
+      // Guard against variadic parameter index
+      const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
+      const auto& param = inputs[schema_input_index];
+      const auto& input_name = param.GetName();
+      if (input_type == nullptr) {
+        if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
+          continue;
+
+        ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
+                  " is absent, but not optional. Op : ", schema.Name());
+      }
+
+      is_variadic_input = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
+      is_homogeneous_input = param.GetIsHomogeneous();
+
+      if (!is_variadic_input || is_homogeneous_input) {
+        auto hit = type_constraints.find(input_name);
+        if (hit != type_constraints.end()) {
+          const auto& types = hit->second;
+          // For custom ops kernel constraints are never empty
+          assert(!types.empty());
+          if (!std::any_of(types.cbegin(), types.cend(),
+                           [input_type](const DataTypeImpl* type) {
+                             return type->IsCompatible(*input_type);
+                           })) {
+            def_selected = nullptr;
+            output_propagate = 0;
+            break;
+          }
+
+          // If the constraint allows multiple types, record the last input type and use it
+          // to guess the output type when the output may have different types.
+          // This works well for symmetric single input/output ops;
+          // otherwise we give up and let the user supply their own inference function.
+          if (types.size() > 1) {
+            output_propagate = input_type->tensor_type().elem_type();
+          }
         } else {
-          matched = false;
+          ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
+                    input_name, " Op: ", schema.Name());
         }
-      } else {
-        matched = false;
-      }
-    }  // for
-    // next, ensure that there is a constraint for every output
-    auto num_outputs = infer_ctx.getNumOutputs();
-    for (size_t i = 0; i < num_outputs && matched; i++) {
-      auto output_name = "Output" + std::to_string(i);
-      auto tc_iter = type_constraints.find(output_name);
-      if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
-        matched = false;
       }
     }
-    if (matched) {
-      for (size_t i = 0; i < num_outputs; i++) {
-        auto output_name = "Output" + std::to_string(i);
-        auto output_type = infer_ctx.getOutputType(i);
-        auto tc_iter = type_constraints.find(output_name);
-        if (tc_iter->second.size() > 1) {
-          output_type->mutable_tensor_type()->set_elem_type(undef);
-        } else {
-          output_type->mutable_tensor_type()->set_elem_type(
-              tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
-        }
-      }
+  }
+
+  if (def_selected == nullptr) {
+    ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
+  }
+
+  const auto& outputs = schema.outputs();
+  const auto node_output_num = infer_ctx.getNumOutputs();
+  const auto& selected_type_constraints = def_selected->TypeConstraints();
+
+  for (size_t i = 0; i < node_output_num; ++i) {
+    auto output_type = infer_ctx.getOutputType(i);
+    // Account for variadic outputs
+    const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
+    const auto& param = outputs[schema_output_index];
+    const auto& output_name = param.GetName();
+
+    const bool is_variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
+    const bool is_homogeneous = param.GetIsHomogeneous();
+
+    // We give up on variadic non-homogeneous outputs and
+    // let the user handle them in their own inference function.
+    if (is_variadic_output && !is_homogeneous) {
       break;
     }
+
+    auto hit = selected_type_constraints.find(output_name);
+    if (hit != selected_type_constraints.end()) {
+      const auto& types = hit->second;
+      assert(!types.empty());
+
+      if (types.size() == 1) {
+        // Use the constraint type
+        output_type->mutable_tensor_type()->set_elem_type(
+            types[0]->GetTypeProto()->tensor_type().elem_type());
+      } else if (!is_variadic_input || is_homogeneous_input) {
+        // If the input is not variadic (or is homogeneous) and multiple types are possible,
+        // guess from the last recorded input type; this works for symmetric single input/output ops.
+        // Otherwise we give up and let the user supply their own inference function.
+        output_type->mutable_tensor_type()->set_elem_type(output_propagate);
+      }
+    } else {
+      ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ",
+                output_name, " Op: ", schema.Name());
+    }
   }
 }
+
 #endif
 
 common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domains,
@@ -1152,13 +1213,13 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
   }
 
   std::vector<ONNX_NAMESPACE::OpSchema> schemas;
-  for (auto schema_iter : schema_map) {
-    schemas.push_back(schema_iter.second);
-    InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]);
+  for (auto& [name, schema] : schema_map) {
+    schemas.push_back(schema);
     auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction();
     ONNX_NAMESPACE::InferenceFunction extended_infer_fn =
-        [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
-          InferOutputTypes(kernel_defs, infer_ctx);
+        [sch = schema, infer_fn = std::move(infer_fn),
+         kernel_defs = std::move(kernel_def_map[name])](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
+          InferOutputTypes(sch, kernel_defs, infer_ctx);
           if (infer_fn) {
             infer_fn(infer_ctx);
           }
diff --git a/onnxruntime/test/framework/shape_inference_test.cc b/onnxruntime/test/framework/shape_inference_test.cc
index bfabcd567803b..f5258760eb20d 100644
--- a/onnxruntime/test/framework/shape_inference_test.cc
+++ b/onnxruntime/test/framework/shape_inference_test.cc
@@ -5,13 +5,16 @@
 #include
 
 #include "gtest/gtest.h"
+#include "core/common/span_utils.h"
 #include "core/graph/model.h"
+#include "core/session/onnxruntime_cxx_api.h"
 #include "test/framework/model_builder_utils.h"
+#include "test/util/include/asserts.h"
 #include "test/util/include/test_utils.h"
+#include "test/util/include/inference_session_wrapper.h"
 #include "test/test_environment.h"
 
 using namespace ONNX_NAMESPACE;
-using namespace std;
 
 namespace onnxruntime {
 namespace test {
@@ -22,7 +25,7 @@ class ShapeInferenceTest : public ::testing::Test {
 protected:
  onnxruntime::Model model_;
  int node_count_;
- std::unordered_map<string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
+ std::unordered_map<std::string, std::unique_ptr<onnxruntime::NodeArg>> name_to_arg_;
 
 public:
  ShapeInferenceTest() : model_("Test", false, DefaultLoggingManager().DefaultLogger()), node_count_(0) {}
@@ -73,5 +76,91 @@ TEST_F(ShapeInferenceTest, BasicTest) {
   CheckShapeEquality(InputShape(node), OutputShape(node));
 }
 
+namespace {
+struct MyCustomKernelWithOptionalInput {
+  MyCustomKernelWithOptionalInput(const OrtKernelInfo* /*info*/) {
+  }
+
+  OrtStatusPtr ComputeV2(OrtKernelContext* /* context */) const {
+    return nullptr;
+  }
+};
+
+struct MyCustomOpWithOptionalInput : Ort::CustomOpBase<MyCustomOpWithOptionalInput, MyCustomKernelWithOptionalInput> {
+  explicit MyCustomOpWithOptionalInput(const char* provider) : provider_(provider) {}
+
+  OrtStatusPtr CreateKernelV2(const OrtApi& /* api */, const OrtKernelInfo* info, void** kernel) const {
+    *kernel = new MyCustomKernelWithOptionalInput(info);
+    return nullptr;
+  };
+
+  const char* GetName() const { return "FooBar"; };
+  const char* GetExecutionProviderType() const { return provider_; };
+
+  size_t GetInputTypeCount() const { return 3; };
+  ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
+    // The second input (index == 1) is optional
+    if (index == 1)
+      return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  }
+
+  size_t GetOutputTypeCount() const { return 1; };
+  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; };
+  OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
+    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+  }
+
+ private:
+  const char* provider_;
+};
+
+const ORTCHAR_T* const OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = ORT_TSTR("testdata/foo_bar_2.onnx");
+
+}  // namespace
+
+// The custom op output type inference function quits when it
+// encounters an output that is optional and absent, without any
+// errors or logging. We want to make sure that inference still
+// proceeds for all of the outputs when optional inputs are absent.
+TEST(ShapeInferenceCustomOpTest, custom_op_optional_input_inference_test) {
+  MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};
+
+  const auto& env = GetEnvironment();
+
+  Ort::CustomOpDomain op_domain("test");
+  op_domain.Add(&custom_op);
+
+  std::initializer_list<OrtCustomOpDomain*> op_domains = {static_cast<OrtCustomOpDomain*>(op_domain)};
+
+  SessionOptions sess_opts;
+  sess_opts.inter_op_param.thread_pool_size = 1;
+  sess_opts.intra_op_param.thread_pool_size = 1;
+
+  InferenceSessionWrapper session{sess_opts, env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2};
+  ASSERT_STATUS_OK(session.AddCustomOpDomains(AsSpan(op_domains)));
+
+  ASSERT_STATUS_OK(session.Load());
+  ASSERT_STATUS_OK(session.Initialize());
+
+  const onnxruntime::Model& model = session.GetModel();
+  const auto& graph = model.MainGraph();
+  const auto& nodes = graph.Nodes();
+  for (const auto& node : nodes) {
+    if (node.OpType() == "FooBar") {
+      // check the inferred output type
+      const auto* node_arg = node.OutputDefs()[0];
+      const auto* type_proto = node_arg->TypeAsProto();
+      ASSERT_NE(nullptr, type_proto);
+      ASSERT_EQ(ONNX_NAMESPACE::TypeProto::ValueCase::kTensorType, type_proto->value_case());
+      ASSERT_EQ(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, type_proto->tensor_type().elem_type());
+    }
+  }
+}
+
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 453b5fdd360bf..987611af212bc 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -208,7 +208,7 @@ static constexpr PATH_TYPE MODEL_WITH_CUSTOM_MODEL_METADATA = TSTR("testdata/mod
 static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/VariedInputCustomOp.onnx");
 static constexpr PATH_TYPE VARIED_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_3.onnx");
 static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_bar_1.onnx");
-static constexpr PATH_TYPE OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
+static constexpr PATH_TYPE OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2 = TSTR("testdata/foo_bar_2.onnx");
 static constexpr PATH_TYPE VARIADIC_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR("testdata/custom_op_variadic_io.onnx");
 static constexpr PATH_TYPE VARIADIC_UNDEF_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI = TSTR(
     "testdata/custom_op_variadic_undef_io.onnx");
@@ -1082,7 +1082,7 @@ TEST(CApiTest, invalid_variadic_input_homogeneity_custom_op) {
   }
 }
 
-TEST(CApiTest, optional_input_output_custom_op_handler) {
+TEST(CApiTest, optional_input_custom_op_handler) {
   MyCustomOpWithOptionalInput custom_op{onnxruntime::kCpuExecutionProvider};
 
   // `MyCustomOpFooBar` defines a custom op with atmost 3 inputs and the second input is optional.
@@ -1147,7 +1147,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
   {
     std::vector<const char*> input_names = {"X1", "X2"};
     ort_inputs.erase(ort_inputs.begin() + 2);  // remove the last input in the container
-    Ort::Session session(*ort_env, OPTIONAL_INPUT_OUTPUT_CUSTOM_OP_MODEL_URI_2, session_options);
+    Ort::Session session(*ort_env, OPTIONAL_INPUT_CUSTOM_OP_MODEL_URI_2, session_options);
     auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                    &output_name, 1);
     ASSERT_EQ(ort_outputs.size(), 1u);
@@ -1166,6 +1166,7 @@ TEST(CApiTest, optional_input_output_custom_op_handler) {
     }
   }
 }
+
 TEST(CApiTest, custom_op_with_attributes_handler) {
   MyCustomOpWithAttributes custom_op{onnxruntime::kCpuExecutionProvider};
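
Usage note: the contract documented above for `KernelContext_GetInput` / `KernelContext_GetOutput` (success with `out == nullptr` when an optional input or output is absent) means kernels must null-check fetched values. Below is a minimal sketch of a kernel compute function against the C API; the function and variable names are hypothetical and not part of this PR, only the `KernelContext_GetInput` call is the real API:

```cpp
#include "onnxruntime_c_api.h"

// Hypothetical compute function for a custom op whose second input (index 1)
// is declared INPUT_OUTPUT_OPTIONAL.
static OrtStatusPtr ComputeFooBar(const OrtApi* api, OrtKernelContext* context) {
  const OrtValue* x1 = nullptr;
  if (OrtStatusPtr st = api->KernelContext_GetInput(context, 0, &x1)) return st;

  // For an absent optional input the call still succeeds, but x2 stays nullptr.
  const OrtValue* x2 = nullptr;
  if (OrtStatusPtr st = api->KernelContext_GetInput(context, 1, &x2)) return st;

  if (x2 == nullptr) {
    // Optional input absent: fall back to the default behavior.
  }
  // ... compute outputs via api->KernelContext_GetOutput(...) ...
  return nullptr;  // success
}
```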
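The same pattern with the C++ wrapper, relying on the new comments in onnxruntime_cxx_api.h (an absent optional input yields an empty `ConstValue` that can be compared to `nullptr`). This is an illustrative sketch, not code from this PR; propagating input 0's shape to output 0 is an assumption made for the example:

```cpp
#include <vector>
#include "onnxruntime_cxx_api.h"

// Hypothetical kernel body: output 0 takes the shape of required input 0.
static OrtStatusPtr ComputeFooBarCxx(OrtKernelContext* context) {
  Ort::KernelContext ctx(context);

  Ort::ConstValue x1 = ctx.GetInput(0);  // required
  Ort::ConstValue x2 = ctx.GetInput(1);  // optional: may be empty

  const std::vector<int64_t> shape = x1.GetTensorTypeAndShapeInfo().GetShape();
  Ort::UnownedValue y = ctx.GetOutput(0, shape);

  if (x2 == nullptr) {
    // Optional input absent: compute from x1 alone.
  }
  // ... fill y ...
  static_cast<void>(y);
  return nullptr;
}
```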