Merge pull request #57 from Microsoft/scmckay/FixInferenceSessionInputValidationHandlingOfOptionalInputs

Support overriding initializers via feed inputs
skottmckay authored Nov 29, 2018
2 parents ca86d8f + 97dc949 commit a4bcb11
Showing 6 changed files with 170 additions and 26 deletions.
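
In practice, the change lets a caller override the value of a graph initializer by passing a feed with the initializer's name, provided the initializer is also declared among the graph inputs. Initializers without a matching feed keep their stored value, so only the true (non-initializer) inputs are required at Run time. A minimal caller-side sketch in C++, mirroring the RunOptionalInputTest added below (the session, run_options, and MLValue setup are assumed):

    // Sketch only: "required_input" is a regular graph input; "optional_input"
    // is an initializer that is also declared as a graph input.
    NameMLValMap feeds;
    feeds.insert({"required_input", required_input_mlvalue});

    // A feed named after the initializer overrides its stored value.
    // Omitting this feed is now valid; the stored value is used instead.
    feeds.insert({"optional_input", override_mlvalue});

    std::vector<std::string> output_names{"add_output"};
    std::vector<MLValue> fetches;
    common::Status status = session.Run(run_options, feeds, output_names, &fetches);

If a feed name matches neither a required nor an optional input, validation now fails with "Invalid Feed Input Names"; a missing required input fails with "Missing required inputs".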
2 changes: 2 additions & 0 deletions .gitignore
@@ -29,3 +29,5 @@ onnxruntime_profile*.json
 /docs/python/examples/*.onnx
 /docs/python/examples/graph.*
 /docs/python/*_LICENSE
+/csharp/**/obj/
+/csharp/**/bin/
8 changes: 4 additions & 4 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -12,7 +12,7 @@
 
 namespace Microsoft.ML.OnnxRuntime.Tests
 {
-    public class InfereceTest
+    public class InferenceTest
     {
         [Fact]
         public void CanCreateAndDisposeSessionWithModelPath()
@@ -113,8 +113,8 @@ private void ThrowWrongInputName()
             var container = new List<NamedOnnxValue>();
             container.Add(NamedOnnxValue.CreateFromTensor<float>("wrong_name", tensor));
             var ex = Assert.Throws<OnnxRuntimeException>(() => session.Run(container));
-            Assert.Equal("[ErrorCode:InvalidArgument] Invalid Feed Input Names: wrong_name Valid input names are: data_0 ", ex.Message);
-            session.Dispose();
+            Assert.Equal("[ErrorCode:InvalidArgument] Missing required inputs: data_0", ex.Message);
+            session.Dispose();
         }
 
         [Fact]
@@ -179,7 +179,7 @@ private void ThrowExtraInputs()
             container.Add(nov1);
             container.Add(nov2);
             var ex = Assert.Throws<OnnxRuntimeException>(() => session.Run(container));
-            Assert.Equal("[ErrorCode:InvalidArgument] The number of feeds is not same as the number of the model input, expect 1 got 2", ex.Message);
+            Assert.StartsWith("[ErrorCode:InvalidArgument] Invalid Feed Input Names: extra. Valid input names are: ", ex.Message);
             session.Dispose();
         }

2 changes: 1 addition & 1 deletion onnxruntime/core/framework/session_state.cc
@@ -115,7 +115,7 @@ void SessionState::AddInputNameToNodeInfoMapping(const std::string& input_name,
 
 common::Status SessionState::GetInputNodeInfo(const std::string& input_name, std::vector<NodeInfo>& node_info_vec) const {
   if (!input_names_to_nodeinfo_mapping_.count(input_name)) {
-    return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping");
+    return Status(ONNXRUNTIME, FAIL, "Failed to find input name in the mapping: " + input_name);
   }
   node_info_vec = input_names_to_nodeinfo_mapping_.at(input_name);
   return Status::OK();
5 changes: 2 additions & 3 deletions onnxruntime/core/framework/session_state_initializer.cc
@@ -486,16 +486,15 @@ static bool IsArgNameInInputsOutputs(const std::string& name,
 common::Status SaveInputOutputNamesToNodeMapping(const onnxruntime::Graph& graph,
                                                  const KernelRegistryManager& custom_registry_manager,
                                                  SessionState& session_state) {
-  auto& weights_map = graph.GetAllInitializedTensors();
-  auto& graph_inputs = graph.GetInputs();
+  auto& graph_inputs = graph.GetInputsIncludingInitializers();
   auto& graph_outputs = graph.GetOutputs();
 
   for (auto& node : graph.Nodes()) {
     ONNXRUNTIME_RETURN_IF_ERROR(
         onnxruntime::Node::ForEachWithIndex(
             node.InputDefs(),
             [&](const onnxruntime::NodeArg& arg, size_t index) {
-              if (arg.Name().empty() || weights_map.count(arg.Name())) {
+              if (arg.Name().empty()) {
                 return Status::OK();
               }
 
54 changes: 36 additions & 18 deletions onnxruntime/core/session/inference_session.cc
@@ -421,10 +421,21 @@ class InferenceSession::Impl {
   }
 
   common::Status ValidateInputNames(const NameMLValMap& feeds) {
-    if (model_input_names_.size() != feeds.size()) {
+    std::string missing_required_inputs;
+
+    std::for_each(required_model_input_names_.cbegin(), required_model_input_names_.cend(),
+                  [&](const std::string& required_input) {
+                    if (feeds.find(required_input) == feeds.cend()) {
+                      if (!missing_required_inputs.empty())
+                        missing_required_inputs += ",";
+
+                      missing_required_inputs += required_input;
+                    }
+                  });
+
+    if (!missing_required_inputs.empty()) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                                     "The number of feeds is not same as the number of the model input, expect ",
-                                     model_input_names_.size(), " got ", feeds.size());
+                                     "Missing required inputs: ", missing_required_inputs);
     }
 
     bool valid = true;
@@ -443,9 +454,9 @@
                     [&ostr](const std::string& elem) {
                       ostr << elem << " ";
                     });
-      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
-                            "Invalid Feed Input Names:" + invalid_names.str() +
-                            " Valid input names are: " + ostr.str());
+      return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                     "Invalid Feed Input Names:", invalid_names.str(),
+                                     ". Valid input names are: ", ostr.str());
     }
 
     return Status::OK();
@@ -804,7 +815,7 @@ class InferenceSession::Impl {
       }
     }
 
-    return std::make_pair(common::Status::OK(), &input_def_list_);
+    return std::make_pair(common::Status::OK(), &required_input_def_list_);
   }
 
   std::pair<common::Status, const OutputDefList*> GetModelOutputs() const {
@@ -896,28 +907,33 @@
     model_metadata_.custom_metadata_map = model.MetaData();
     model_metadata_.graph_name = graph.Name();
 
-    // save inputs
-    auto& inputs = graph.GetInputs();  // inputs excluding initializers
-    input_def_list_.reserve(inputs.size());
-    for (const auto& elem : inputs) {
-      if (!elem) {
-        return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null input nodearg ptr");
-      }
+    // save required inputs
+    const auto& required_inputs = graph.GetInputs();  // inputs excluding initializers
+    required_input_def_list_.reserve(required_inputs.size());
+    required_model_input_names_.reserve(required_inputs.size());
+    for (const auto& elem : required_inputs) {
+      required_input_def_list_.push_back(elem);
+      required_model_input_names_.insert(elem->Name());
+    }
+
+    // save all valid inputs
+    const auto& all_inputs = graph.GetInputsIncludingInitializers();
+    input_def_list_.reserve(all_inputs.size());
+    model_input_names_.reserve(all_inputs.size());
+    for (const auto& elem : all_inputs) {
       input_def_list_.push_back(elem);
       model_input_names_.insert(elem->Name());
     }
 
     // save outputs
-    auto& outputs = graph.GetOutputs();
+    const auto& outputs = graph.GetOutputs();
     output_def_list_.reserve(outputs.size());
     model_output_names_.reserve(outputs.size());
     for (const auto& elem : outputs) {
-      if (!elem) {
-        return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null output nodearg ptr");
-      }
       output_def_list_.push_back(elem);
       model_output_names_.insert(elem->Name());
     }
 
     VLOGS(*session_logger_, 1) << "Done saving model metadata";
     return common::Status::OK();
   }
@@ -1030,10 +1046,12 @@ class InferenceSession::Impl {
   SessionState session_state_;
 
   ModelMetadata model_metadata_;
+  InputDefList required_input_def_list_;
   InputDefList input_def_list_;
   OutputDefList output_def_list_;
 
   // names of model inputs and outputs used for quick validation.
+  std::unordered_set<std::string> required_model_input_names_;
   std::unordered_set<std::string> model_input_names_;
   std::unordered_set<std::string> model_output_names_;
125 changes: 125 additions & 0 deletions onnxruntime/test/framework/inference_session_test.cc
@@ -25,6 +25,7 @@
 #include "core/session/IOBinding.h"
 #include "test/capturing_sink.h"
 #include "test/test_environment.h"
+#include "test/providers/provider_test_utils.h"
 #include "test_utils.h"
 #include "gtest/gtest.h"

@@ -808,6 +809,130 @@ TEST(InferenceSessionTests, ModelWithoutOpset) {
   }
 }
 
+static ONNX_NAMESPACE::ModelProto CreateModelWithOptionalInputs() {
+  Model model("ModelWithOptionalInputs");
+  auto& graph = model.MainGraph();
+
+  // create an initializer, which is an optional input that can be overridden
+  onnx::TensorProto tensor_proto;
+  tensor_proto.add_dims(1);
+  tensor_proto.set_data_type(TensorProto_DataType_FLOAT);
+  tensor_proto.add_float_data(1.f);
+  tensor_proto.set_name("optional_input");
+
+  graph.AddInitializedTensor(tensor_proto);
+
+  TypeProto single_float;
+  single_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  single_float.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+
+  auto& required_input = graph.GetOrCreateNodeArg("required_input", &single_float);
+  auto& optional_input = graph.GetOrCreateNodeArg("optional_input", nullptr);
+  auto& add_output = graph.GetOrCreateNodeArg("add_output", &single_float);
+
+  EXPECT_TRUE(optional_input.Shape() != nullptr) << "AddInitializedTensor should have created the NodeArg with shape.";
+
+  graph.AddNode("add", "Add", "Add required and optional inputs", {&required_input, &optional_input}, {&add_output});
+
+  auto status = graph.Resolve();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  auto model_proto = model.ToProto();
+
+  return model_proto;
+}
+
+static common::Status RunOptionalInputTest(bool add_required_input,
+                                           bool add_optional_input,
+                                           bool add_invalid_input) {
+  auto model_proto = CreateModelWithOptionalInputs();
+
+  SessionOptions so;
+  so.session_logid = "InferenceSessionTests.TestOptionalInputs";
+
+  InferenceSession session_object{so, &DefaultLoggingManager()};
+
+  std::stringstream s1;
+  model_proto.SerializeToOstream(&s1);
+  auto status = session_object.Load(s1);
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = session_object.Initialize();
+  EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  RunOptions run_options;
+  run_options.run_tag = so.session_logid;
+
+  // prepare inputs
+  std::vector<int64_t> dims = {1};
+  std::vector<float> required_input_val = {1.f};
+  std::vector<float> optional_input_val = {10.f};  // override initializer value of 1
+  std::vector<float> unknown_input_val = {20.f};
+
+  MLValue required_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, required_input_val, &required_input_mlvalue);
+
+  MLValue optional_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, optional_input_val, &optional_input_mlvalue);
+
+  MLValue unknown_input_mlvalue;
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, ONNXRuntimeMemTypeDefault),
+                       dims, unknown_input_val, &unknown_input_mlvalue);
+
+  NameMLValMap feeds;
+
+  if (add_required_input)
+    feeds.insert(std::make_pair("required_input", required_input_mlvalue));
+
+  if (add_optional_input)
+    feeds.insert(std::make_pair("optional_input", optional_input_mlvalue));
+
+  if (add_invalid_input)
+    feeds.insert(std::make_pair("unknown_input", unknown_input_mlvalue));
+
+  // prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("add_output");
+  std::vector<MLValue> fetches;
+
+  float expected_value = required_input_val[0];
+  expected_value += add_optional_input ? optional_input_val[0] : 1.f;
+
+  status = session_object.Run(run_options, feeds, output_names, &fetches);
+
+  if (status.IsOK()) {
+    MLValue& output = fetches.front();
+    const auto& tensor = output.Get<Tensor>();
+    float output_value = *tensor.Data<float>();
+    if (output_value != expected_value) {
+      status = ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output of ", output_value, " != ", expected_value);
+    }
+  }
+
+  return status;
+}
+
+TEST(InferenceSessionTests, TestOptionalInputs) {
+  // required input only
+  auto status = RunOptionalInputTest(true, false, false);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // required and optional input
+  status = RunOptionalInputTest(true, true, false);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+
+  // required, optional and invalid input
+  status = RunOptionalInputTest(true, true, true);
+  ASSERT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid Feed Input Names: unknown_input"));
+
+  // missing required
+  status = RunOptionalInputTest(false, true, false);
+  ASSERT_FALSE(status.IsOK());
+  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing required inputs: required_input"));
+}
+
 TEST(ExecutionProviderTest, FunctionTest) {
   onnxruntime::Model model("graph_1");
   auto& graph = model.MainGraph();
