Allow for an optional subgraph input to have no type info. (#10379)
Add a test for a missing optional input to Loop.
skottmckay authored Jan 29, 2022
1 parent 85cbe83 commit baa1767
Showing 2 changed files with 26 additions and 21 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/core/graph/graph.cc
@@ -2131,11 +2131,16 @@ Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,

   // apply type/shape info to the subgraph's inputs
   for (size_t i = 0; i < num_subgraph_inputs; ++i) {
-    const auto& input_type = *input_types[i];
+    const auto* input_type = input_types[i];
+    if (input_type == nullptr) {
+      // optional input
+      continue;
+    }
+
     const auto& subgraph_input = *subgraph_inputs->at(i);

     NodeArg* mutable_nodearg = subgraph.GetNodeArg(subgraph_input.Name());
-    status = mutable_nodearg->UpdateTypeAndShape(input_type, true, options.override_types, subgraph.logger_);
+    status = mutable_nodearg->UpdateTypeAndShape(*input_type, true, options.override_types, subgraph.logger_);
     if (!status.IsOK()) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node:", node.Name(), " ", status.ErrorMessage());
     }
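The essence of the change above is that an entry in input_types may now legitimately be null when the corresponding optional subgraph input was not provided, and such entries are skipped instead of being dereferenced. The following is a minimal, self-contained sketch of that control flow only; the TypeProto struct and the sample inputs are hypothetical stand-ins, not the real ONNX Runtime types.

#include <cstdio>
#include <vector>

// Hypothetical stand-in for ONNX Runtime's TypeProto; used only to show the
// control flow of the fix, not the real class.
struct TypeProto {
  const char* type_name;
};

int main() {
  TypeProto int64_type{"tensor(int64)"};
  TypeProto double_type{"tensor(double)"};

  // One entry per subgraph input. A null pointer models an optional input the
  // caller did not provide, so there is no type/shape info to propagate.
  std::vector<const TypeProto*> input_types = {&int64_type, nullptr, &double_type};

  for (size_t i = 0; i < input_types.size(); ++i) {
    const auto* input_type = input_types[i];
    if (input_type == nullptr) {
      // Previously this pointer was dereferenced unconditionally; the fix
      // skips a missing optional input instead.
      std::printf("input %zu: optional input not provided, skipping\n", i);
      continue;
    }
    std::printf("input %zu: applying %s\n", i, input_type->type_name);
  }
  return 0;
}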
38 changes: 19 additions & 19 deletions onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -596,7 +596,7 @@ TEST(Loop, SubgraphTypeOverride) {
std::vector<NodeArg*> inputs;
std::vector<NodeArg*> outputs;

  /*
Inputs: iter_num, cond_in, fake_in, loop carried state variables.
iter_num_in cond_in fake_in [outer_scope_0]
@@ -671,7 +671,7 @@ TEST(Loop, SubgraphTypeOverride) {
LoopOpTester test{{}, create_subgraph};

test.AddInput<int64_t>("M", {1}, {1});
-  test.AddInput<bool>("cond", {1}, {true});
+  test.AddOptionalInputEdge<bool>();  // 'cond' is optional in this test so don't provide it
test.AddInput<double>("fake", {1}, {0.f});
test.AddInput<double>("outer_scope_0", {1}, {kOuterNodeAddValue});

@@ -799,8 +799,8 @@ TEST(Loop, Opset11WithNoVariadicInputsAndOutputs) {

auto* constant_attribute_tensor_proto = attr_proto.mutable_t();
constant_attribute_tensor_proto->mutable_dims()->Clear(); // scalar
-  constant_attribute_tensor_proto->set_data_type(TensorProto_DataType_FLOAT); //float scalar
-  *constant_attribute_tensor_proto->mutable_float_data()->Add() = 1.0f; //float scalar with value 1.0f
+  constant_attribute_tensor_proto->set_data_type(TensorProto_DataType_FLOAT);  // float scalar
+  *constant_attribute_tensor_proto->mutable_float_data()->Add() = 1.0f;        // float scalar with value 1.0f

constant_node.AddAttribute("value", attr_proto);
}
@@ -977,11 +977,11 @@ TEST(Loop, IterationCountAsOutput) {

/* Inputs: iter_num, cond_in, loop carried state variables.
      iter_num_in     cond_in
           |             |
      [Identity]     [Identity]
           |             |
    loop_var_0_out    cond_out
*/

// graph inputs types.
@@ -1061,12 +1061,12 @@ TEST(Loop, SequenceAsLoopCarriedDependency) {
Inputs: iter_num, cond_in, loop_var_0_in
   loop_var_0_in   inserted_tensor    cond_in     iter_num
         |                |              |        (unused)
    [SequenceInsert]------/          [Identity]
         |                               |
         |                            cond_out
   loop_var_0_out
*/

// graph inputs types.
@@ -1184,12 +1184,12 @@ TEST(Loop, OptionalTypeAsLoopCarriedDependency) {
Inputs: iter_num, cond_in, loop_var_0_in
   loop_var_0_in     cond_in     iter_num
         |              |        (unused)
     [Identity]     [Identity]
         |              |
         |           cond_out
   loop_var_0_out
*/

// graph inputs types.
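The new test coverage above omits the Loop node's optional 'cond' input (via AddOptionalInputEdge<bool>() instead of AddInput) so that the corresponding subgraph input has no type info to inherit, which is exactly the case the graph.cc change now tolerates. As background, ONNX denotes a missing optional input with an empty name in the node's input list. Below is a small standalone sketch of that convention only; it is not the ONNX Runtime test framework, and the input names simply mirror the test above.

#include <cstdio>
#include <string>
#include <vector>

int main() {
  // Positional inputs of a hypothetical Loop node: M, cond, fake, outer_scope_0.
  // The empty name in slot 1 marks the optional 'cond' input as not provided.
  std::vector<std::string> loop_inputs = {"M", "", "fake", "outer_scope_0"};

  for (size_t i = 0; i < loop_inputs.size(); ++i) {
    if (loop_inputs[i].empty()) {
      std::printf("input %zu: optional, not provided -> no type info to propagate\n", i);
    } else {
      std::printf("input %zu: '%s'\n", i, loop_inputs[i].c_str());
    }
  }
  return 0;
}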
