diff --git a/.gitignore b/.gitignore
index a73fe987ba3c1..dd44eb6d846c1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,3 +29,5 @@ onnxruntime_profile*.json
 /docs/python/examples/*.onnx
 /docs/python/examples/graph.*
 /docs/python/*_LICENSE
+/csharp/**/obj/
+/csharp/**/bin/
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index b7a747f76f311..e185cb6796cec 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -12,7 +12,7 @@
 namespace Microsoft.ML.OnnxRuntime.Tests
 {
-    public class InfereceTest
+    public class InferenceTest
     {
         [Fact]
         public void CanCreateAndDisposeSessionWithModelPath()
@@ -113,8 +113,8 @@ private void ThrowWrongInputName()
            var container = new List<NamedOnnxValue>();
            container.Add(NamedOnnxValue.CreateFromTensor<float>("wrong_name", tensor));
            var ex = Assert.Throws<OnnxRuntimeException>(() => session.Run(container));
-           Assert.Equal("[ErrorCode:InvalidArgument] Invalid Feed Input Names: wrong_name Valid input names are: data_0 ", ex.Message);
-           session.Dispose();
+           Assert.Equal("[ErrorCode:InvalidArgument] Missing required inputs: data_0", ex.Message);
+           session.Dispose();
        }

        [Fact]
@@ -179,7 +179,7 @@ private void ThrowExtraInputs()
            container.Add(nov1);
            container.Add(nov2);
            var ex = Assert.Throws<OnnxRuntimeException>(() => session.Run(container));
-           Assert.Equal("[ErrorCode:InvalidArgument] The number of feeds is not same as the number of the model input, expect 1 got 2", ex.Message);
+           Assert.StartsWith("[ErrorCode:InvalidArgument] Invalid Feed Input Names: extra. Valid input names are: ", ex.Message);
            session.Dispose();
        }
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index b3621cb1e1680..d126fe5e5995a 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -115,7 +115,7 @@ void SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name,
 common::Status SessionState::GetInputNodeInfo(const std::string& input_name,
                                               std::vector<NodeInfo>& node_info_vec) const {
   if (!input_names_to_nodeinfo_mapping_.count(input_name)) {
-    return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping");
+    return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping: " + input_name);
   }
   node_info_vec = input_names_to_nodeinfo_mapping_.at(input_name);
   return Status::OK();
diff --git a/onnxruntime/core/framework/session_state_initializer.cc b/onnxruntime/core/framework/session_state_initializer.cc
index 5cef936a71514..edceb317f771d 100644
--- a/onnxruntime/core/framework/session_state_initializer.cc
+++ b/onnxruntime/core/framework/session_state_initializer.cc
@@ -486,8 +486,7 @@ static bool IsArgNameInInputsOutputs(const std::string& name,
 common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
                                                  const KernelRegistryManager& custom_registry_manager,
                                                  SessionState& session_state) {
-  auto& weights_map = graph.GetAllInitializedTensors();
-  auto& graph_inputs = graph.GetInputs();
+  auto& graph_inputs = graph.GetInputsIncludingInitializers();
   auto& graph_outputs = graph.GetOutputs();

   for (auto& node : graph.Nodes()) {
@@ -495,7 +494,7 @@ common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph
     onnxruntime::Node::ForEachWithIndex(
         node.InputDefs(),
         [&](const onnxruntime::NodeArg& arg, size_t index) {
-          if (arg.Name().empty() || weights_map.count(arg.Name())) {
+          if (arg.Name().empty()) {
             return Status::OK();
           }
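Note: with the switch to GetInputsIncludingInitializers() above, initializer-backed graph inputs are no longer dropped from the input-name-to-node mapping, which is what lets a feed override an initializer at Run() time. As a reading aid, a minimal self-contained sketch of how the two Graph accessors used in this change relate; the helper name GetOptionalInputNames is hypothetical, not part of this PR:

#include <string>
#include <unordered_set>

#include "core/graph/graph.h"

// Optional (overridable) inputs are exactly those returned by
// GetInputsIncludingInitializers() but not by GetInputs().
std::unordered_set<std::string> GetOptionalInputNames(const onnxruntime::Graph& graph) {
  std::unordered_set<std::string> required;
  for (const auto* def : graph.GetInputs())  // excludes initializers
    required.insert(def->Name());

  std::unordered_set<std::string> optional;
  for (const auto* def : graph.GetInputsIncludingInitializers())
    if (required.count(def->Name()) == 0)
      optional.insert(def->Name());  // backed by an initializer; a feed may override it

  return optional;
}
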
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index e8b580ae31155..6178e7810c562 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -421,10 +421,21 @@ class InferenceSession::Impl {
   }

   common::Status ValidateInputNames(const NameMLValMap& feeds) {
-    if (model_input_names_.size() != feeds.size()) {
+    std::string missing_required_inputs;
+
+    std::for_each(required_model_input_names_.cbegin(), required_model_input_names_.cend(),
+                  [&](const std::string& required_input) {
+                    if (feeds.find(required_input) == feeds.cend()) {
+                      if (!missing_required_inputs.empty())
+                        missing_required_inputs += ",";
+
+                      missing_required_inputs += required_input;
+                    }
+                  });
+
+    if (!missing_required_inputs.empty()) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                                     "The number of feeds is not same as the number of the model input, expect ",
-                                     model_input_names_.size(), " got ", feeds.size());
+                                     "Missing required inputs: ", missing_required_inputs);
     }

     bool valid = true;
@@ -443,9 +454,9 @@ class InferenceSession::Impl {
                     [&ostr](const std::string& elem) {
                       ostr << elem << " ";
                     });
-      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
-                            "Invalid Feed Input Names:" + invalid_names.str() +
-                            " Valid input names are: " + ostr.str());
+      return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                     "Invalid Feed Input Names:", invalid_names.str(),
+                                     ". Valid input names are: ", ostr.str());
     }

     return Status::OK();
@@ -804,7 +815,7 @@ class InferenceSession::Impl {
       }
     }

-    return std::make_pair(common::Status::OK(), &input_def_list_);
+    return std::make_pair(common::Status::OK(), &required_input_def_list_);
   }

   std::pair<common::Status, const OutputDefList*> GetModelOutputs() const {
@@ -896,28 +907,33 @@ class InferenceSession::Impl {
     model_metadata_.custom_metadata_map = model.MetaData();
     model_metadata_.graph_name = graph.Name();

-    // save inputs
-    auto& inputs = graph.GetInputs();  // inputs excluding initializers
-    input_def_list_.reserve(inputs.size());
-    for (const auto& elem : inputs) {
-      if (!elem) {
-        return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null input nodearg ptr");
-      }
+    // save required inputs
+    const auto& required_inputs = graph.GetInputs();  // inputs excluding initializers
+    required_input_def_list_.reserve(required_inputs.size());
+    required_model_input_names_.reserve(required_inputs.size());
+    for (const auto& elem : required_inputs) {
+      required_input_def_list_.push_back(elem);
+      required_model_input_names_.insert(elem->Name());
+    }
+
+    // save all valid inputs
+    const auto& all_inputs = graph.GetInputsIncludingInitializers();
+    input_def_list_.reserve(all_inputs.size());
+    model_input_names_.reserve(all_inputs.size());
+    for (const auto& elem : all_inputs) {
       input_def_list_.push_back(elem);
       model_input_names_.insert(elem->Name());
     }

     // save outputs
-    auto& outputs = graph.GetOutputs();
+    const auto& outputs = graph.GetOutputs();
     output_def_list_.reserve(outputs.size());
+    model_output_names_.reserve(outputs.size());
     for (const auto& elem : outputs) {
-      if (!elem) {
-        return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null output nodearg ptr");
-      }
       output_def_list_.push_back(elem);
       model_output_names_.insert(elem->Name());
     }
+
     VLOGS(*session_logger_, 1) << "Done saving model metadata";
     return common::Status::OK();
   }
@@ -1030,10 +1046,12 @@ class InferenceSession::Impl {
   SessionState session_state_;
   ModelMetadata model_metadata_;
+  InputDefList required_input_def_list_;
   InputDefList input_def_list_;
   OutputDefList output_def_list_;

   // names of model inputs and outputs used for quick validation.
+  std::unordered_set<std::string> required_model_input_names_;
   std::unordered_set<std::string> model_input_names_;
   std::unordered_set<std::string> model_output_names_;
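Note: ValidateInputNames now validates in two ordered steps: first it reports every required input with no matching feed ("Missing required inputs: ..."), and only then flags feed names the model does not recognize ("Invalid Feed Input Names: ..."). A standalone sketch of that ordering, decoupled from InferenceSession; ValidateFeeds and its parameters are illustrative stand-ins, with FeedMap playing the role of NameMLValMap (the mapped type is irrelevant to the checks):

#include <set>
#include <string>

// Returns an empty string when the feeds are acceptable, otherwise an
// error message mirroring the two checks above.
template <typename FeedMap>
std::string ValidateFeeds(const std::set<std::string>& required_names,
                          const std::set<std::string>& all_valid_names,
                          const FeedMap& feeds) {
  std::string missing;
  for (const auto& name : required_names)
    if (feeds.find(name) == feeds.cend()) {
      if (!missing.empty()) missing += ",";
      missing += name;
    }
  if (!missing.empty())
    return "Missing required inputs: " + missing;  // checked first

  std::string invalid;
  for (const auto& entry : feeds)
    if (all_valid_names.count(entry.first) == 0)
      invalid += " " + entry.first;
  if (!invalid.empty())
    return "Invalid Feed Input Names:" + invalid;  // checked second

  return "";
}
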
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index f6209e0a62450..e83dcaddeb1a8 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -25,6 +25,7 @@
 #include "core/session/IOBinding.h"
 #include "test/capturing_sink.h"
 #include "test/test_environment.h"
+#include "test/providers/provider_test_utils.h"
 #include "test_utils.h"

 #include "gtest/gtest.h"
@@ -808,6 +809,130 @@ TEST(InferenceSessionTests, ModelWithoutOpset) {
   }
 }

+static ONNX_NAMESPACE::ModelProto CreateModelWithOptionalInputs() {
+  Model model("ModelWithOptionalInputs");
+  auto& graph = model.MainGraph();
+
+  // create an initializer, which is an optional input that can be overridden
+  onnx::TensorProto tensor_proto;
+  tensor_proto.add_dims(1);
+  tensor_proto.set_data_type(TensorProto_DataType_FLOAT);
+  tensor_proto.add_float_data(1.f);
+  tensor_proto.set_name("optional_input");
+
+  graph.AddInitializedTensor(tensor_proto);
+
+  TypeProto single_float;
+  single_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  single_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+
+  auto& required_input = graph.GetOrCreateNodeArg("required_input", &single_float);
+  auto& optional_input = graph.GetOrCreateNodeArg("optional_input", nullptr);
+  auto& add_output = graph.GetOrCreateNodeArg("add_output", &single_float);
+
+  EXPECT_TRUE(optional_input.Shape() != nullptr) << "AddInitializedTensor should have created the NodeArg with shape.";
+
+  graph.AddNode("add", "Add", "Add required and optional inputs", {&required_input, &optional_input}, {&add_output});
+
+  auto status = graph.Resolve();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  auto model_proto = model.ToProto();
+
+  return model_proto;
+}
+
+static common::Status RunOptionalInputTest(bool add_required_input,
+                                           bool add_optional_input,
+                                           bool add_invalid_input) {
+  auto model_proto = CreateModelWithOptionalInputs();
+
+  SessionOptions so;
+  so.session_logid = "InferenceSessionTests.TestOptionalInputs";
+
+  InferenceSession session_object{so, &DefaultLoggingManager()};
+
+  std::stringstream s1;
+  model_proto.SerializeToOstream(&s1);
+  auto status = session_object.Load(s1);
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = session_object.Initialize();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  // prepare inputs
+  std::vector<int64_t> dims = {1};
+  std::vector<float> required_input_val = {1.f};
+  std::vector<float> optional_input_val = {10.f};  // override initializer value of 1
+  std::vector<float> unknown_input_val = {20.f};
+
+  MLValue required_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, required_input_val, &required_input_mlvalue);
+
+  MLValue optional_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, optional_input_val, &optional_input_mlvalue);
+
+  MLValue unknown_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, unknown_input_val, &unknown_input_mlvalue);
+
+  NameMLValMap feeds;
+
+  if (add_required_input)
+    feeds.insert(std::make_pair("required_input", required_input_mlvalue));
+
+  if (add_optional_input)
+    feeds.insert(std::make_pair("optional_input", optional_input_mlvalue));
+
+  if (add_invalid_input)
+    feeds.insert(std::make_pair("unknown_input", unknown_input_mlvalue));
+
+  // prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("add_output");
+  std::vector<MLValue> fetches;
+
+  float expected_value = required_input_val[0];
+  expected_value += add_optional_input ? optional_input_val[0] : 1.f;
+
+  status = session_object.Run(run_options, feeds, output_names, &fetches);
+
+  if (status.IsOK()) {
+    MLValue& output = fetches.front();
+    const auto& tensor = output.Get<Tensor>();
+    float output_value = *tensor.Data<float>();
+    if (output_value != expected_value) {
+      status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output of ", output_value, " != ", expected_value);
+    }
+  }
+
+  return status;
+}
+
+TEST(InferenceSessionTests, TestOptionalInputs) {
+  // required input only
+  auto status = RunOptionalInputTest(true, false, false);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // required and optional input
+  status = RunOptionalInputTest(true, true, false);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // required, optional and invalid input
+  status = RunOptionalInputTest(true, true, true);
+  ASSERT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Names: unknown_input"));
+
+  // missing required
+  status = RunOptionalInputTest(false, true, false);
+  ASSERT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing required inputs: required_input"));
+}
+
 TEST(ExecutionProviderTest, FunctionTest) {
   onnxruntime::Model model("graph_1");
   auto& graph = model.MainGraph();
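
Note: end to end, the behavior the new test pins down looks like this from calling code. A fragment, not a full program: it assumes session_object, run_options, and required_input_mlvalue were prepared exactly as in RunOptionalInputTest above, and the 2.f check simply restates the test's expected_value arithmetic:

NameMLValMap feeds;
feeds.insert(std::make_pair("required_input", required_input_mlvalue));
// "optional_input" is omitted, so the session falls back to its initializer value (1.f).

std::vector<std::string> output_names{"add_output"};
std::vector<MLValue> fetches;

auto status = session_object.Run(run_options, feeds, output_names, &fetches);
if (status.IsOK()) {
  // add_output == required_input (1.f) + initializer default (1.f) == 2.f
  float result = *fetches.front().Get<Tensor>().Data<float>();
  assert(result == 2.f);  // requires <cassert>
}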